Delete mllm
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- mllm/flamingo/__init__.py +0 -48
- mllm/flamingo/config.json +0 -21
- mllm/flamingo/configuration_flamingo.py +0 -100
- mllm/flamingo/converting_flamingo_to_bf16.py +0 -30
- mllm/flamingo/converting_flamingo_to_hf.py +0 -61
- mllm/flamingo/converting_flamingo_to_lora.py +0 -68
- mllm/flamingo/falcon/__init__.py +0 -0
- mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc +0 -0
- mllm/flamingo/falcon/configuration_RW.py +0 -79
- mllm/flamingo/falcon/modelling_RW.py +0 -1064
- mllm/flamingo/flamingo-falcon-7B.json +0 -112
- mllm/flamingo/flamingo-llama2-chat-13B.json +0 -114
- mllm/flamingo/flamingo-llama2-chat-7B.json +0 -115
- mllm/flamingo/flamingo-mpt-1B-redpajama.json +0 -131
- mllm/flamingo/flamingo-mpt-30B-bf16.json +0 -195
- mllm/flamingo/flamingo-mpt-30B.json +0 -195
- mllm/flamingo/flamingo-mpt-7B.json +0 -195
- mllm/flamingo/flamingo-vicuna-33B-v1.3.json +0 -111
- mllm/flamingo/flamingo-vicuna-7B-v1.3.json +0 -111
- mllm/flamingo/injecting_falcon_into_flamingo.py +0 -49
- mllm/flamingo/injecting_llama2_into_flamingo.py +0 -95
- mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py +0 -97
- mllm/flamingo/injecting_mpt_into_flamingo.py +0 -109
- mllm/flamingo/injecting_vicuna_into_flamingo.py +0 -100
- mllm/flamingo/modeling_flamingo.py +0 -966
- mllm/flamingo/mpt/__init__.py +0 -0
- mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc +0 -0
- mllm/flamingo/mpt/adapt_tokenizer.py +0 -44
- mllm/flamingo/mpt/attention.py +0 -450
- mllm/flamingo/mpt/blocks.py +0 -82
- mllm/flamingo/mpt/configuration_mpt.py +0 -161
- mllm/flamingo/mpt/custom_embedding.py +0 -11
- mllm/flamingo/mpt/flash_attn_triton.py +0 -841
- mllm/flamingo/mpt/hf_prefixlm_converter.py +0 -575
- mllm/flamingo/mpt/meta_init_context.py +0 -98
- mllm/flamingo/mpt/modeling_mpt.py +0 -496
- mllm/flamingo/mpt/norm.py +0 -60
- mllm/flamingo/mpt/param_init_fns.py +0 -369
- mllm/flamingo/mpt_redpajama/__init__.py +0 -0
- mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc +0 -0
mllm/flamingo/__init__.py
DELETED
|
@@ -1,48 +0,0 @@
|
|
| 1 |
-
from typing import TYPE_CHECKING
|
| 2 |
-
|
| 3 |
-
from transformers.utils import (
|
| 4 |
-
OptionalDependencyNotAvailable,
|
| 5 |
-
_LazyModule,
|
| 6 |
-
is_torch_available,
|
| 7 |
-
)
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
_import_structure = {
|
| 11 |
-
"configuration_flamingo": [
|
| 12 |
-
"FlamingoConfig",
|
| 13 |
-
],
|
| 14 |
-
}
|
| 15 |
-
|
| 16 |
-
try:
|
| 17 |
-
if not is_torch_available():
|
| 18 |
-
raise OptionalDependencyNotAvailable()
|
| 19 |
-
except OptionalDependencyNotAvailable:
|
| 20 |
-
pass
|
| 21 |
-
else:
|
| 22 |
-
_import_structure["modeling_flamingo"] = [
|
| 23 |
-
"FlamingoModel",
|
| 24 |
-
"FlamingoPreTrainedModel",
|
| 25 |
-
"FlamingoForConditionalGeneration",
|
| 26 |
-
]
|
| 27 |
-
|
| 28 |
-
if TYPE_CHECKING:
|
| 29 |
-
from .configuration_flamingo import FlamingoConfig
|
| 30 |
-
|
| 31 |
-
# from .processing_flamingo import FlamingoProcessor
|
| 32 |
-
|
| 33 |
-
try:
|
| 34 |
-
if not is_torch_available():
|
| 35 |
-
raise OptionalDependencyNotAvailable()
|
| 36 |
-
except OptionalDependencyNotAvailable:
|
| 37 |
-
pass
|
| 38 |
-
else:
|
| 39 |
-
from .modeling_flamingo import (
|
| 40 |
-
FlamingoForConditionalGeneration,
|
| 41 |
-
FlamingoModel,
|
| 42 |
-
FlamingoPreTrainedModel,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
else:
|
| 46 |
-
import sys
|
| 47 |
-
|
| 48 |
-
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/config.json
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"model_type": "flamingo",
|
| 3 |
-
"cross_attn_every_n_layers": 4,
|
| 4 |
-
"tie_word_embeddings": false,
|
| 5 |
-
"use_media_placement_augmentation": true,
|
| 6 |
-
"only_attend_previous": true,
|
| 7 |
-
"text_config": {
|
| 8 |
-
"_name_or_path": "luodian/llama-7b-hf",
|
| 9 |
-
"model_type": "llama"
|
| 10 |
-
},
|
| 11 |
-
"vision_config": {
|
| 12 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 13 |
-
"model_type": "clip_vision_model",
|
| 14 |
-
"hidden_size": 1024,
|
| 15 |
-
"intermediate_size": 4096,
|
| 16 |
-
"num_attention_heads": 16,
|
| 17 |
-
"num_hidden_layers": 24,
|
| 18 |
-
"image_size": 224,
|
| 19 |
-
"patch_size": 14
|
| 20 |
-
}
|
| 21 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/configuration_flamingo.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
|
| 3 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 4 |
-
from transformers.utils import logging
|
| 5 |
-
|
| 6 |
-
from transformers.models.auto import CONFIG_MAPPING
|
| 7 |
-
from transformers.models.clip import CLIPVisionConfig
|
| 8 |
-
import sys
|
| 9 |
-
|
| 10 |
-
from .falcon.configuration_RW import RWConfig
|
| 11 |
-
from .mpt.configuration_mpt import MPTConfig
|
| 12 |
-
from .mpt_redpajama.configuration_mosaic_gpt import MosaicGPTConfig
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
logger = logging.get_logger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class FlamingoConfig(PretrainedConfig):
|
| 19 |
-
r"""
|
| 20 |
-
[`FlamingoConfig`] is the configuration class to store the configuration of a [`FlamingoForConditionalGeneration`]. It is
|
| 21 |
-
used to instantiate a Flamingo model according to the specified arguments, defining the vision model and language model configs. Instantiating a configuration with the defaults will yield a similar configuration to
|
| 22 |
-
that of the Flamingo architecture.
|
| 23 |
-
|
| 24 |
-
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 25 |
-
documentation from [`PretrainedConfig`] for more information.
|
| 26 |
-
|
| 27 |
-
Args:
|
| 28 |
-
vision_config (`dict`, *optional*):
|
| 29 |
-
Dictionary of configuration options used to initialize [`PretrainedConfig`].
|
| 30 |
-
text_config (`dict`, *optional*):
|
| 31 |
-
Dictionary of configuration options used to initialize any [`PretrainedConfig`].
|
| 32 |
-
cross_attn_every_n_layers (`int`, *optional*, defaults to 4):
|
| 33 |
-
The number of cross-attention layers adding after each transformer layer.
|
| 34 |
-
|
| 35 |
-
kwargs (*optional*):
|
| 36 |
-
Dictionary of keyword arguments.
|
| 37 |
-
|
| 38 |
-
Example:
|
| 39 |
-
|
| 40 |
-
```python
|
| 41 |
-
>>> from transformers import (
|
| 42 |
-
... PretrainedConfig,
|
| 43 |
-
... OPTConfig,
|
| 44 |
-
... FlamingoConfig,
|
| 45 |
-
... FlamingoForConditionalGeneration,
|
| 46 |
-
... )
|
| 47 |
-
|
| 48 |
-
>>> # Initializing a FlamingoConfig with Salesforce/Flamingo-opt-2.7b style configuration
|
| 49 |
-
>>> configuration = FlamingoConfig()
|
| 50 |
-
|
| 51 |
-
>>> # Initializing a FlamingoForConditionalGeneration (with random weights) from the Salesforce/Flamingo-opt-2.7b style configuration
|
| 52 |
-
>>> model = FlamingoForConditionalGeneration(configuration)
|
| 53 |
-
```"""
|
| 54 |
-
model_type = "flamingo"
|
| 55 |
-
is_composition = True
|
| 56 |
-
|
| 57 |
-
def __init__(self, vision_config=None, text_config=None, cross_attn_every_n_layers: int = 4, use_media_placement_augmentation: bool = True, **kwargs):
|
| 58 |
-
super().__init__(**kwargs)
|
| 59 |
-
if vision_config is None:
|
| 60 |
-
vision_config = {}
|
| 61 |
-
logger.info("vision_config is None. initializing the vision config with default values.")
|
| 62 |
-
|
| 63 |
-
if text_config is None:
|
| 64 |
-
text_config = {}
|
| 65 |
-
logger.info("text_config is None. Initializing the text config with default values.")
|
| 66 |
-
|
| 67 |
-
self.vision_config = CLIPVisionConfig(**vision_config)
|
| 68 |
-
if "architectures" in text_config.keys() and text_config["architectures"] != None:
|
| 69 |
-
if text_config["architectures"][0] == "MPTForCausalLM":
|
| 70 |
-
self.text_config = MPTConfig(**text_config)
|
| 71 |
-
elif text_config["architectures"][0] == "MosaicGPT":
|
| 72 |
-
self.text_config = MosaicGPTConfig(**text_config)
|
| 73 |
-
elif text_config["architectures"][0] == "RWForCausalLM":
|
| 74 |
-
self.text_config = RWConfig(**text_config)
|
| 75 |
-
elif text_config["architectures"][0] == "LlamaForCausalLM":
|
| 76 |
-
self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
|
| 77 |
-
else:
|
| 78 |
-
import pdb
|
| 79 |
-
|
| 80 |
-
pdb.set_trace()
|
| 81 |
-
else:
|
| 82 |
-
self.text_config = CONFIG_MAPPING[text_config.pop("model_type")](**text_config)
|
| 83 |
-
|
| 84 |
-
self.cross_attn_every_n_layers = cross_attn_every_n_layers
|
| 85 |
-
self.use_media_placement_augmentation = use_media_placement_augmentation
|
| 86 |
-
|
| 87 |
-
def to_dict(self):
|
| 88 |
-
"""
|
| 89 |
-
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
| 90 |
-
|
| 91 |
-
Returns:
|
| 92 |
-
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
| 93 |
-
"""
|
| 94 |
-
output = copy.deepcopy(self.__dict__)
|
| 95 |
-
output["vision_config"] = self.vision_config.to_dict()
|
| 96 |
-
output["text_config"] = self.text_config.to_dict()
|
| 97 |
-
output["model_type"] = self.__class__.model_type
|
| 98 |
-
output["cross_attn_every_n_layers"] = self.cross_attn_every_n_layers
|
| 99 |
-
output["use_media_placement_augmentation"] = self.use_media_placement_augmentation
|
| 100 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/converting_flamingo_to_bf16.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
|
| 6 |
-
from .configuration_flamingo import FlamingoConfig
|
| 7 |
-
from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 8 |
-
|
| 9 |
-
parser = argparse.ArgumentParser(description="Load model with precision")
|
| 10 |
-
parser.add_argument("--load_bit", type=str, choices=["fp16", "bf16"], required=True, help="Choose either 'fp16' or 'bf16'")
|
| 11 |
-
parser.add_argument("--pretrained_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
|
| 12 |
-
parser.add_argument("--saved_model_path", type=str, default="/home/luodian/projects/checkpoints/flamingo-mpt-7B-instruct-init", required=True)
|
| 13 |
-
args = parser.parse_args()
|
| 14 |
-
|
| 15 |
-
load_bit = args.load_bit
|
| 16 |
-
pretrained_model_path = args.pretrained_model_path
|
| 17 |
-
|
| 18 |
-
if load_bit == "fp16":
|
| 19 |
-
precision = {"torch_dtype": torch.float16}
|
| 20 |
-
elif load_bit == "bf16":
|
| 21 |
-
precision = {"torch_dtype": torch.bfloat16}
|
| 22 |
-
|
| 23 |
-
root_dir = os.environ["AZP"]
|
| 24 |
-
print(root_dir)
|
| 25 |
-
device_id = "cpu"
|
| 26 |
-
model = FlamingoForConditionalGeneration.from_pretrained(pretrained_model_path, device_map={"": device_id}, **precision)
|
| 27 |
-
|
| 28 |
-
# save model to same folder
|
| 29 |
-
checkpoint_path = pretrained_model_path + f"-{load_bit}"
|
| 30 |
-
model.save_pretrained(checkpoint_path, max_shard_size="10GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/converting_flamingo_to_hf.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
| 1 |
-
"""convert from otter pt to otter hf. Will remove after we use otter hf model to train.
|
| 2 |
-
"""
|
| 3 |
-
|
| 4 |
-
import re
|
| 5 |
-
import argparse
|
| 6 |
-
import os
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
|
| 11 |
-
|
| 12 |
-
import sys
|
| 13 |
-
from modeling_flamingo import FlamingoForConditionalGeneration
|
| 14 |
-
|
| 15 |
-
from configuration_flamingo import FlamingoConfig
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
@torch.no_grad()
|
| 19 |
-
def dump_hf_model(pretrained_model_path: str, old_ckpt_path: str, new_folder_path: str) -> None:
|
| 20 |
-
old_ckpt = torch.load(old_ckpt_path, map_location="cpu")
|
| 21 |
-
if old_ckpt.get("model_state_dict", None) is not None:
|
| 22 |
-
old_ckpt = old_ckpt["model_state_dict"]
|
| 23 |
-
new_ckpt = old_ckpt
|
| 24 |
-
folder_path = os.path.dirname(old_ckpt_path)
|
| 25 |
-
# config_path = os.path.join(folder_path, "config.json") if os.path.exists(os.path.join(folder_path, "config.json")) else "flamingo/config.json"
|
| 26 |
-
model = FlamingoForConditionalGeneration.from_pretrained(
|
| 27 |
-
args.pretrained_model_path,
|
| 28 |
-
device_map="auto",
|
| 29 |
-
)
|
| 30 |
-
_ = model.load_state_dict(new_ckpt, strict=False)
|
| 31 |
-
print(f"Saving HF model to {new_folder_path}")
|
| 32 |
-
model.save_pretrained(new_folder_path)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
if __name__ == "__main__":
|
| 36 |
-
parser = argparse.ArgumentParser()
|
| 37 |
-
parser.add_argument(
|
| 38 |
-
"--old_ckpt_path",
|
| 39 |
-
"-old",
|
| 40 |
-
type=str,
|
| 41 |
-
required=True,
|
| 42 |
-
help="Path to the pt checkpoint",
|
| 43 |
-
)
|
| 44 |
-
parser.add_argument(
|
| 45 |
-
"--new_hf_path",
|
| 46 |
-
"-new",
|
| 47 |
-
type=str,
|
| 48 |
-
required=True,
|
| 49 |
-
help="Path to the hf folder",
|
| 50 |
-
)
|
| 51 |
-
parser.add_argument(
|
| 52 |
-
"--pretrained_model_path",
|
| 53 |
-
"-pretrained",
|
| 54 |
-
type=str,
|
| 55 |
-
required=True,
|
| 56 |
-
help="Path to the pretrained model folder",
|
| 57 |
-
)
|
| 58 |
-
args = parser.parse_args()
|
| 59 |
-
if not os.path.exists(os.path.dirname(args.new_hf_path)):
|
| 60 |
-
os.makedirs(os.path.dirname(args.new_hf_path))
|
| 61 |
-
dump_hf_model(args.pretrained_model_path, args.old_ckpt_path, args.new_hf_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/converting_flamingo_to_lora.py
DELETED
|
@@ -1,68 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import torch
|
| 3 |
-
import sys
|
| 4 |
-
|
| 5 |
-
from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 6 |
-
from peft import get_peft_model, LoraConfig, TaskType
|
| 7 |
-
|
| 8 |
-
MODEL_CLASSES = {
|
| 9 |
-
"LlamaForCausalLM": "llama",
|
| 10 |
-
"OPTForCausalLM": "opt",
|
| 11 |
-
"GPTJForCausalLM": "gptj",
|
| 12 |
-
"GPTNeoXForCausalLM": "gpt_neox",
|
| 13 |
-
"MPTForCausalLM": "mpt",
|
| 14 |
-
}
|
| 15 |
-
|
| 16 |
-
# Define argument parser
|
| 17 |
-
parser = argparse.ArgumentParser(description="Load a model with specified precision and save it to a specified path.")
|
| 18 |
-
|
| 19 |
-
# Add arguments
|
| 20 |
-
parser.add_argument(
|
| 21 |
-
"--checkpoint_path",
|
| 22 |
-
type=str,
|
| 23 |
-
help="Path to the pre-trained model checkpoint.",
|
| 24 |
-
default="",
|
| 25 |
-
)
|
| 26 |
-
parser.add_argument(
|
| 27 |
-
"--save_path",
|
| 28 |
-
type=str,
|
| 29 |
-
default="",
|
| 30 |
-
help="Path to the converted model checkpoint.",
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
# Parse the input arguments
|
| 34 |
-
args = parser.parse_args()
|
| 35 |
-
|
| 36 |
-
load_bit = "bf16"
|
| 37 |
-
if load_bit == "fp16":
|
| 38 |
-
precision = {"torch_dtype": torch.float16}
|
| 39 |
-
elif load_bit == "bf16":
|
| 40 |
-
precision = {"torch_dtype": torch.bfloat16}
|
| 41 |
-
|
| 42 |
-
# Load the model
|
| 43 |
-
model = FlamingoForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto", **precision)
|
| 44 |
-
|
| 45 |
-
# adding lora
|
| 46 |
-
standard_modules = ["q_proj", "v_proj"]
|
| 47 |
-
lang_encoder_short_name = MODEL_CLASSES[model.config.text_config.architectures[0]]
|
| 48 |
-
model_to_lora_modules = {
|
| 49 |
-
"llama": standard_modules,
|
| 50 |
-
"opt": standard_modules,
|
| 51 |
-
"gptj": standard_modules,
|
| 52 |
-
"gpt_neox": ["query_key_value"],
|
| 53 |
-
"mpt": ["Wqkv"],
|
| 54 |
-
}
|
| 55 |
-
lora_config = LoraConfig(
|
| 56 |
-
r=16,
|
| 57 |
-
lora_alpha=32,
|
| 58 |
-
lora_dropout=0.05,
|
| 59 |
-
task_type=TaskType.CAUSAL_LM,
|
| 60 |
-
target_modules=model_to_lora_modules[lang_encoder_short_name],
|
| 61 |
-
)
|
| 62 |
-
model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
|
| 63 |
-
model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
|
| 64 |
-
model.lang_encoder.print_trainable_parameters()
|
| 65 |
-
|
| 66 |
-
# Save the model
|
| 67 |
-
checkpoint_path = args.save_path
|
| 68 |
-
FlamingoForConditionalGeneration.save_pretrained(model, checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/falcon/__init__.py
DELETED
|
File without changes
|
mllm/flamingo/falcon/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (209 Bytes)
|
|
|
mllm/flamingo/falcon/__pycache__/configuration_RW.cpython-39.pyc
DELETED
|
Binary file (1.86 kB)
|
|
|
mllm/flamingo/falcon/__pycache__/modelling_RW.cpython-39.pyc
DELETED
|
Binary file (28.5 kB)
|
|
|
mllm/flamingo/falcon/configuration_RW.py
DELETED
|
@@ -1,79 +0,0 @@
|
|
| 1 |
-
# coding=utf-8
|
| 2 |
-
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
|
| 3 |
-
#
|
| 4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
-
# you may not use this file except in compliance with the License.
|
| 6 |
-
# You may obtain a copy of the License at
|
| 7 |
-
#
|
| 8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
-
#
|
| 10 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
-
# See the License for the specific language governing permissions and
|
| 14 |
-
# limitations under the License.
|
| 15 |
-
""" Bloom configuration"""
|
| 16 |
-
from transformers.configuration_utils import PretrainedConfig
|
| 17 |
-
from transformers.utils import logging
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
logger = logging.get_logger(__name__)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class RWConfig(PretrainedConfig):
|
| 24 |
-
model_type = "RefinedWebModel"
|
| 25 |
-
keys_to_ignore_at_inference = ["past_key_values"]
|
| 26 |
-
attribute_map = {
|
| 27 |
-
"num_hidden_layers": "n_layer",
|
| 28 |
-
"num_attention_heads": "n_head",
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
vocab_size=250880,
|
| 34 |
-
hidden_size=64,
|
| 35 |
-
n_layer=2,
|
| 36 |
-
n_head=8,
|
| 37 |
-
layer_norm_epsilon=1e-5,
|
| 38 |
-
initializer_range=0.02,
|
| 39 |
-
use_cache=True,
|
| 40 |
-
bos_token_id=1,
|
| 41 |
-
eos_token_id=2,
|
| 42 |
-
apply_residual_connection_post_layernorm=False,
|
| 43 |
-
hidden_dropout=0.0,
|
| 44 |
-
attention_dropout=0.0,
|
| 45 |
-
multi_query=False,
|
| 46 |
-
alibi=False,
|
| 47 |
-
bias=False,
|
| 48 |
-
parallel_attn=False,
|
| 49 |
-
**kwargs,
|
| 50 |
-
):
|
| 51 |
-
self.vocab_size = vocab_size
|
| 52 |
-
# Backward compatibility with n_embed kwarg
|
| 53 |
-
n_embed = kwargs.pop("n_embed", None)
|
| 54 |
-
self.hidden_size = hidden_size if n_embed is None else n_embed
|
| 55 |
-
self.n_layer = n_layer
|
| 56 |
-
self.n_head = n_head
|
| 57 |
-
self.layer_norm_epsilon = layer_norm_epsilon
|
| 58 |
-
self.initializer_range = initializer_range
|
| 59 |
-
self.use_cache = use_cache
|
| 60 |
-
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
| 61 |
-
self.hidden_dropout = hidden_dropout
|
| 62 |
-
self.attention_dropout = attention_dropout
|
| 63 |
-
|
| 64 |
-
self.bos_token_id = bos_token_id
|
| 65 |
-
self.eos_token_id = eos_token_id
|
| 66 |
-
self.multi_query = multi_query
|
| 67 |
-
self.alibi = alibi
|
| 68 |
-
self.bias = bias
|
| 69 |
-
self.parallel_attn = parallel_attn
|
| 70 |
-
|
| 71 |
-
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
| 72 |
-
|
| 73 |
-
@property
|
| 74 |
-
def head_dim(self):
|
| 75 |
-
return self.hidden_size // self.n_head
|
| 76 |
-
|
| 77 |
-
@property
|
| 78 |
-
def rotary(self):
|
| 79 |
-
return not self.alibi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/falcon/modelling_RW.py
DELETED
|
@@ -1,1064 +0,0 @@
|
|
| 1 |
-
# port of models described in RW
|
| 2 |
-
# We use the bloom model as a starting point for these model.
|
| 3 |
-
# Please refer to the bloom models for usage instructions.
|
| 4 |
-
|
| 5 |
-
import math
|
| 6 |
-
import warnings
|
| 7 |
-
from typing import Optional, Tuple, Union
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.utils.checkpoint
|
| 11 |
-
from torch import nn
|
| 12 |
-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
| 13 |
-
from torch.nn import functional as F
|
| 14 |
-
|
| 15 |
-
from transformers.modeling_outputs import (
|
| 16 |
-
BaseModelOutputWithPastAndCrossAttentions,
|
| 17 |
-
CausalLMOutputWithCrossAttentions,
|
| 18 |
-
QuestionAnsweringModelOutput,
|
| 19 |
-
SequenceClassifierOutputWithPast,
|
| 20 |
-
TokenClassifierOutput,
|
| 21 |
-
)
|
| 22 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 23 |
-
from transformers.utils import logging
|
| 24 |
-
from .configuration_RW import RWConfig
|
| 25 |
-
|
| 26 |
-
logger = logging.get_logger(__name__)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
|
| 30 |
-
# In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
|
| 31 |
-
class Linear(nn.Linear):
|
| 32 |
-
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 33 |
-
ret = input @ self.weight.T
|
| 34 |
-
if self.bias is None:
|
| 35 |
-
return ret
|
| 36 |
-
else:
|
| 37 |
-
return ret + self.bias
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
from einops import rearrange
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
|
| 44 |
-
def rotate_half(x):
|
| 45 |
-
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 46 |
-
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class RotaryEmbedding(torch.nn.Module):
|
| 50 |
-
"""Implementation of RotaryEmbedding from GPT-NeoX.
|
| 51 |
-
This implementation is design to operate on queries and keys that are compatible with
|
| 52 |
-
[batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
head_dim: int,
|
| 58 |
-
base=10000,
|
| 59 |
-
):
|
| 60 |
-
super().__init__()
|
| 61 |
-
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 62 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 63 |
-
self.head_dim = head_dim
|
| 64 |
-
self.seq_len_cached = None
|
| 65 |
-
self.batch_size_cached = None
|
| 66 |
-
self.cos_cached: torch.Tensor | None = None
|
| 67 |
-
self.sin_cached: torch.Tensor | None = None
|
| 68 |
-
|
| 69 |
-
def cos_sin(
|
| 70 |
-
self,
|
| 71 |
-
seq_len: int,
|
| 72 |
-
device="cuda",
|
| 73 |
-
dtype=torch.bfloat16,
|
| 74 |
-
) -> torch.Tensor:
|
| 75 |
-
if seq_len != self.seq_len_cached:
|
| 76 |
-
self.seq_len_cached = seq_len
|
| 77 |
-
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
|
| 78 |
-
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 79 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
| 80 |
-
|
| 81 |
-
if dtype in [torch.float16, torch.bfloat16]:
|
| 82 |
-
emb = emb.float()
|
| 83 |
-
|
| 84 |
-
self.cos_cached = emb.cos()[None, :, :]
|
| 85 |
-
self.sin_cached = emb.sin()[None, :, :]
|
| 86 |
-
|
| 87 |
-
self.cos_cached = self.cos_cached.type(dtype)
|
| 88 |
-
self.sin_cached = self.sin_cached.type(dtype)
|
| 89 |
-
|
| 90 |
-
return self.cos_cached, self.sin_cached
|
| 91 |
-
|
| 92 |
-
def forward(self, q, k):
|
| 93 |
-
batch, seq_len, head_dim = q.shape
|
| 94 |
-
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
|
| 95 |
-
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def _make_causal_mask(input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int) -> torch.BoolTensor:
|
| 99 |
-
batch_size, target_length = input_ids_shape
|
| 100 |
-
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
| 101 |
-
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
| 102 |
-
seq_ids = torch.arange(target_length, device=device)
|
| 103 |
-
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
|
| 104 |
-
|
| 105 |
-
if past_key_values_length > 0:
|
| 106 |
-
mask[:, :past_key_values_length] = False
|
| 107 |
-
|
| 108 |
-
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
| 109 |
-
return expanded_mask
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
| 113 |
-
batch_size, src_length = mask.shape
|
| 114 |
-
tgt_length = tgt_length if tgt_length is not None else src_length
|
| 115 |
-
|
| 116 |
-
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
|
| 117 |
-
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
| 121 |
-
batch_size, seq_length = attention_mask.shape
|
| 122 |
-
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
| 123 |
-
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
|
| 124 |
-
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
| 125 |
-
slopes = torch.pow(base, powers)
|
| 126 |
-
|
| 127 |
-
if closest_power_of_2 != num_heads:
|
| 128 |
-
extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32)
|
| 129 |
-
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
| 130 |
-
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
|
| 131 |
-
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 132 |
-
|
| 133 |
-
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
|
| 134 |
-
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
| 135 |
-
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
| 136 |
-
# => the query_length dimension will then be broadcasted correctly
|
| 137 |
-
# This is more or less identical to T5's relative position bias:
|
| 138 |
-
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
|
| 139 |
-
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
| 140 |
-
alibi = slopes[..., None].bfloat16() * arange_tensor
|
| 141 |
-
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
|
| 145 |
-
out = F.dropout(x, p=prob, training=training)
|
| 146 |
-
out = residual + out
|
| 147 |
-
return out
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
class Attention(nn.Module):
    """Self-attention block for the RW (Falcon-style) decoder.

    Uses a fused QKV projection. Supports standard multi-head attention or
    multi-query attention (one shared K/V head), and either rotary position
    embeddings (applied internally via `maybe_rotary`) or ALiBi biases
    (passed into `forward` as `alibi`).
    """

    def __init__(self, config: RWConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" f" {self.num_heads}).")

        # No-op on (q, k) when rotary embeddings are disabled.
        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = self.inv_norm_factor

        # Fused projection: 3*hidden for MHA, hidden + 2*head_dim for multi-query.
        self.query_key_value = Linear(
            self.hidden_size,
            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
            bias=config.bias,
        )
        self.multi_query = config.multi_query
        self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        # Number of distinct key/value heads (1 in multi-query mode).
        self.num_kv = config.n_head if not self.multi_query else 1

    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
        storage as `fused_qkv`

        Args:
            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
            value: [batch_size, seq_length, num_heads, head_dim]
        """
        if not self.multi_query:
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
        else:
            # Multi-query layout: num_heads query heads followed by one key and one value head.
            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimenstion

        Args:
            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        # Flatten batch and head dims for the rotary/cache path.
        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
        key_layer = key_layer.transpose(1, 2).reshape(
            batch_size * self.num_kv,
            q_length,
            self.head_dim,
        )
        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

        # NOTE(review): rotary is applied BEFORE concatenating the past cache,
        # i.e. only to the current slice — confirm position handling is intended
        # when generating with use_cache.
        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

        if layer_past is not None:
            past_key, past_value = layer_past
            # concatenate along seq_length dimension:
            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

        _, kv_length, _ = key_layer.shape

        if use_cache is True:
            present = (key_layer, value_layer)
        else:
            present = None

        if alibi is None:
            # Rotary / no-bias path: use PyTorch's fused SDPA kernel with a causal mask.
            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

            attn_output = F.scaled_dot_product_attention(query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True)

            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
            x = x.permute(0, 2, 1, 3)
            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

            output_tensor = self.dense(attn_output)

            outputs = (output_tensor, present)
            assert not output_attentions  # not supported.
            return outputs
        else:
            # ALiBi path: explicit matmul + softmax so the position bias can be added.
            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
            matmul_result = query_layer @ key_layer.transpose(-1, -2)

            # change view to [batch_size, num_heads, q_length, kv_length]
            attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
            input_dtype = attention_scores.dtype
            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                attention_scores = attention_scores.to(torch.float32)
            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
            attention_probs = F.softmax(
                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
                dim=-1,
                dtype=hidden_states.dtype,
            )
            # [batch_size, num_heads, q_length, kv_length]
            attention_probs = self.attention_dropout(attention_probs)

            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            # change view [batch_size x num_heads, q_length, kv_length]
            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

            # matmul: [batch_size * num_heads, q_length, head_dim]
            context_layer = attention_probs_reshaped @ value_layer

            # change view [batch_size, num_heads, q_length, head_dim]
            context_layer = self._merge_heads(context_layer)

            output_tensor = self.dense(context_layer)

            outputs = (output_tensor, present)
            if output_attentions:
                outputs += (attention_probs,)

            return outputs
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
class MLP(nn.Module):
|
| 326 |
-
def __init__(self, config: RWConfig):
|
| 327 |
-
super().__init__()
|
| 328 |
-
hidden_size = config.hidden_size
|
| 329 |
-
|
| 330 |
-
self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
|
| 331 |
-
self.act = nn.GELU()
|
| 332 |
-
self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
|
| 333 |
-
self.hidden_dropout = config.hidden_dropout
|
| 334 |
-
|
| 335 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 336 |
-
x = self.act(self.dense_h_to_4h(x))
|
| 337 |
-
x = self.dense_4h_to_h(x)
|
| 338 |
-
return x
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
class DecoderLayer(nn.Module):
    """Single RW decoder layer: self-attention + MLP with residual connections.

    When `config.parallel_attn` is set, attention and MLP both read the same
    layernormed input and their outputs are summed (Falcon-style); otherwise
    they run sequentially with a post-attention layernorm.
    """

    def __init__(self, config: RWConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = Attention(config)

        if not config.parallel_attn:
            # unused if parallel attn
            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = MLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

        self.config = config

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        layernorm_output = self.input_layernorm(hidden_states)
        residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        if not self.config.parallel_attn:
            # Sequential path: residual add after attention, then re-normalize for the MLP.
            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
            layernorm_output = self.post_attention_layernorm(residual)

        outputs = attn_outputs[1:]

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        if self.config.parallel_attn:
            # Parallel path: attention and MLP outputs are summed before the residual add.
            mlp_output += attention_output

        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

        if use_cache:
            outputs = (output,) + outputs
        else:
            # Drop the `present=None` placeholder when not caching.
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
class RWPreTrainedModel(PreTrainedModel):
    # Base class wiring the RW architecture into the HF PreTrainedModel API:
    # weight init, gradient checkpointing toggle, and KV-cache layout converters.
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RWConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear) or isinstance(module, Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        # Only RWModel carries the gradient_checkpointing flag.
        if isinstance(module, RWModel):
            module.gradient_checkpointing = value

    @staticmethod
    def _convert_to_standard_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
        num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    @staticmethod
    def _convert_to_rw_cache(past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
        # Inverse of _convert_to_standard_cache: re-fuse the batch and head dims.
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
class RWModel(RWPreTrainedModel):
    """Decoder-only RW transformer: token embeddings -> DecoderLayer stack -> final LayerNorm."""

    def __init__(self, config: RWConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head
        # Whether to use ALiBi position biases instead of rotary embeddings.
        self.alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

        # Transformer blocks
        self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def _prepare_attn_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
        """Combine the causal mask with the padding mask into one boolean mask (True = masked)."""
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(input_shape, device=device, past_key_values_length=past_key_values_length)

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask

        return combined_attention_mask

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the decoder stack; returns hidden states (+ optional cache / attentions / per-layer states)."""
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = inputs_embeds

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            # assumes cached key layout is [batch*heads, head_dim, kv_length] — see Attention comments
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        if self.alibi:
            alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        else:
            alibi = None

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Attention weights follow the present-cache entry when caching is on.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
class RWForCausalLM(RWPreTrainedModel):
    """RWModel with a linear LM head for causal language modeling."""

    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: RWConfig):
        super().__init__(config)
        self.transformer = RWModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
            if past[0][0].shape[0] == input_ids.shape[0]:
                past = self._convert_to_rw_cache(past)

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past}
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_rw_cache(reordered_past)
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
class RWForSequenceClassification(RWPreTrainedModel):
    """RWModel with a linear classification head pooled on the last non-padding token."""

    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config: RWConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = RWModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the last non-padding token in each sequence.
                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # Pool: pick the logits at the last (non-padding) position per sequence.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels / label dtype if unset.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
class RWForTokenClassification(RWPreTrainedModel):
|
| 899 |
-
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
| 900 |
-
|
| 901 |
-
def __init__(self, config: RWConfig):
|
| 902 |
-
super().__init__(config)
|
| 903 |
-
self.num_labels = config.num_labels
|
| 904 |
-
|
| 905 |
-
self.transformer = RWModel(config)
|
| 906 |
-
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
| 907 |
-
classifier_dropout = config.classifier_dropout
|
| 908 |
-
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
| 909 |
-
classifier_dropout = config.hidden_dropout
|
| 910 |
-
else:
|
| 911 |
-
classifier_dropout = 0.1
|
| 912 |
-
self.dropout = nn.Dropout(classifier_dropout)
|
| 913 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 914 |
-
|
| 915 |
-
# Initialize weights and apply final processing
|
| 916 |
-
self.post_init()
|
| 917 |
-
|
| 918 |
-
def forward(
|
| 919 |
-
self,
|
| 920 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 921 |
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 922 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 923 |
-
head_mask: Optional[torch.Tensor] = None,
|
| 924 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
| 925 |
-
labels: Optional[torch.Tensor] = None,
|
| 926 |
-
use_cache: Optional[bool] = None,
|
| 927 |
-
output_attentions: Optional[bool] = None,
|
| 928 |
-
output_hidden_states: Optional[bool] = None,
|
| 929 |
-
return_dict: Optional[bool] = None,
|
| 930 |
-
**deprecated_arguments,
|
| 931 |
-
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
|
| 932 |
-
r"""
|
| 933 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 934 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
| 935 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 936 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 937 |
-
"""
|
| 938 |
-
if deprecated_arguments.pop("position_ids", False) is not False:
|
| 939 |
-
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
| 940 |
-
warnings.warn(
|
| 941 |
-
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" " passing `position_ids`.",
|
| 942 |
-
FutureWarning,
|
| 943 |
-
)
|
| 944 |
-
if len(deprecated_arguments) > 0:
|
| 945 |
-
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
| 946 |
-
|
| 947 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 948 |
-
|
| 949 |
-
transformer_outputs = self.transformer(
|
| 950 |
-
input_ids,
|
| 951 |
-
past_key_values=past_key_values,
|
| 952 |
-
attention_mask=attention_mask,
|
| 953 |
-
head_mask=head_mask,
|
| 954 |
-
inputs_embeds=inputs_embeds,
|
| 955 |
-
use_cache=use_cache,
|
| 956 |
-
output_attentions=output_attentions,
|
| 957 |
-
output_hidden_states=output_hidden_states,
|
| 958 |
-
return_dict=return_dict,
|
| 959 |
-
)
|
| 960 |
-
|
| 961 |
-
hidden_states = transformer_outputs[0]
|
| 962 |
-
hidden_states = self.dropout(hidden_states)
|
| 963 |
-
logits = self.classifier(hidden_states)
|
| 964 |
-
|
| 965 |
-
loss = None
|
| 966 |
-
if labels is not None:
|
| 967 |
-
batch_size, seq_length = labels.shape
|
| 968 |
-
loss_fct = CrossEntropyLoss()
|
| 969 |
-
loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
|
| 970 |
-
|
| 971 |
-
if not return_dict:
|
| 972 |
-
output = (logits,) + transformer_outputs[2:]
|
| 973 |
-
return ((loss,) + output) if loss is not None else output
|
| 974 |
-
|
| 975 |
-
return TokenClassifierOutput(
|
| 976 |
-
loss=loss,
|
| 977 |
-
logits=logits,
|
| 978 |
-
hidden_states=transformer_outputs.hidden_states,
|
| 979 |
-
attentions=transformer_outputs.attentions,
|
| 980 |
-
)
|
| 981 |
-
|
| 982 |
-
|
| 983 |
-
class RWForQuestionAnswering(RWPreTrainedModel):
|
| 984 |
-
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
| 985 |
-
|
| 986 |
-
def __init__(self, config):
|
| 987 |
-
super().__init__(config)
|
| 988 |
-
self.transformer = RWModel(config)
|
| 989 |
-
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
| 990 |
-
|
| 991 |
-
# Initialize weights and apply final processing
|
| 992 |
-
self.post_init()
|
| 993 |
-
|
| 994 |
-
def forward(
|
| 995 |
-
self,
|
| 996 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 997 |
-
attention_mask: Optional[torch.FloatTensor] = None,
|
| 998 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 999 |
-
head_mask: Optional[torch.FloatTensor] = None,
|
| 1000 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1001 |
-
start_positions: Optional[torch.LongTensor] = None,
|
| 1002 |
-
end_positions: Optional[torch.LongTensor] = None,
|
| 1003 |
-
output_attentions: Optional[bool] = None,
|
| 1004 |
-
output_hidden_states: Optional[bool] = None,
|
| 1005 |
-
return_dict: Optional[bool] = None,
|
| 1006 |
-
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
| 1007 |
-
r"""
|
| 1008 |
-
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1009 |
-
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1010 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1011 |
-
are not taken into account for computing the loss.
|
| 1012 |
-
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1013 |
-
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1014 |
-
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
| 1015 |
-
are not taken into account for computing the loss.
|
| 1016 |
-
"""
|
| 1017 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1018 |
-
|
| 1019 |
-
outputs = self.transformer(
|
| 1020 |
-
input_ids,
|
| 1021 |
-
attention_mask=attention_mask,
|
| 1022 |
-
position_ids=position_ids,
|
| 1023 |
-
head_mask=head_mask,
|
| 1024 |
-
inputs_embeds=inputs_embeds,
|
| 1025 |
-
output_attentions=output_attentions,
|
| 1026 |
-
output_hidden_states=output_hidden_states,
|
| 1027 |
-
return_dict=return_dict,
|
| 1028 |
-
)
|
| 1029 |
-
|
| 1030 |
-
sequence_output = outputs[0]
|
| 1031 |
-
|
| 1032 |
-
logits = self.qa_outputs(sequence_output)
|
| 1033 |
-
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1034 |
-
start_logits = start_logits.squeeze(-1).contiguous()
|
| 1035 |
-
end_logits = end_logits.squeeze(-1).contiguous()
|
| 1036 |
-
|
| 1037 |
-
total_loss = None
|
| 1038 |
-
if start_positions is not None and end_positions is not None:
|
| 1039 |
-
# If we are on multi-GPU, split add a dimension
|
| 1040 |
-
if len(start_positions.size()) > 1:
|
| 1041 |
-
start_positions = start_positions.squeeze(-1)
|
| 1042 |
-
if len(end_positions.size()) > 1:
|
| 1043 |
-
end_positions = end_positions.squeeze(-1)
|
| 1044 |
-
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1045 |
-
ignored_index = start_logits.size(1)
|
| 1046 |
-
start_positions = start_positions.clamp(0, ignored_index)
|
| 1047 |
-
end_positions = end_positions.clamp(0, ignored_index)
|
| 1048 |
-
|
| 1049 |
-
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1050 |
-
start_loss = loss_fct(start_logits, start_positions)
|
| 1051 |
-
end_loss = loss_fct(end_logits, end_positions)
|
| 1052 |
-
total_loss = (start_loss + end_loss) / 2
|
| 1053 |
-
|
| 1054 |
-
if not return_dict:
|
| 1055 |
-
output = (start_logits, end_logits) + outputs[2:]
|
| 1056 |
-
return ((total_loss,) + output) if total_loss is not None else output
|
| 1057 |
-
|
| 1058 |
-
return QuestionAnsweringModelOutput(
|
| 1059 |
-
loss=total_loss,
|
| 1060 |
-
start_logits=start_logits,
|
| 1061 |
-
end_logits=end_logits,
|
| 1062 |
-
hidden_states=outputs.hidden_states,
|
| 1063 |
-
attentions=outputs.attentions,
|
| 1064 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-falcon-7B.json
DELETED
|
@@ -1,112 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoModel"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 4,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"architectures": [
|
| 10 |
-
"RWForCausalLM"
|
| 11 |
-
],
|
| 12 |
-
"apply_residual_connection_post_layernorm": false,
|
| 13 |
-
"attention_dropout": 0.0,
|
| 14 |
-
"bias": false,
|
| 15 |
-
"bos_token_id": 11,
|
| 16 |
-
"eos_token_id": 11,
|
| 17 |
-
"hidden_dropout": 0.0,
|
| 18 |
-
"hidden_size": 4544,
|
| 19 |
-
"initializer_range": 0.02,
|
| 20 |
-
"layer_norm_epsilon": 1e-05,
|
| 21 |
-
"model_type": "RefinedWebModel",
|
| 22 |
-
"multi_query": true,
|
| 23 |
-
"n_head": 71,
|
| 24 |
-
"n_layer": 32,
|
| 25 |
-
"parallel_attn": true,
|
| 26 |
-
"torch_dtype": "bfloat16",
|
| 27 |
-
"transformers_version": "4.27.4",
|
| 28 |
-
"use_cache": true,
|
| 29 |
-
"vocab_size": 65024
|
| 30 |
-
},
|
| 31 |
-
"tie_word_embeddings": false,
|
| 32 |
-
"torch_dtype": "float32",
|
| 33 |
-
"transformers_version": null,
|
| 34 |
-
"use_media_placement_augmentation": true,
|
| 35 |
-
"vision_config": {
|
| 36 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 37 |
-
"add_cross_attention": false,
|
| 38 |
-
"architectures": null,
|
| 39 |
-
"attention_dropout": 0.0,
|
| 40 |
-
"bad_words_ids": null,
|
| 41 |
-
"begin_suppress_tokens": null,
|
| 42 |
-
"bos_token_id": null,
|
| 43 |
-
"chunk_size_feed_forward": 0,
|
| 44 |
-
"cross_attention_hidden_size": null,
|
| 45 |
-
"decoder_start_token_id": null,
|
| 46 |
-
"diversity_penalty": 0.0,
|
| 47 |
-
"do_sample": false,
|
| 48 |
-
"early_stopping": false,
|
| 49 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 50 |
-
"eos_token_id": null,
|
| 51 |
-
"exponential_decay_length_penalty": null,
|
| 52 |
-
"finetuning_task": null,
|
| 53 |
-
"forced_bos_token_id": null,
|
| 54 |
-
"forced_eos_token_id": null,
|
| 55 |
-
"hidden_act": "quick_gelu",
|
| 56 |
-
"hidden_size": 1024,
|
| 57 |
-
"id2label": {
|
| 58 |
-
"0": "LABEL_0",
|
| 59 |
-
"1": "LABEL_1"
|
| 60 |
-
},
|
| 61 |
-
"image_size": 224,
|
| 62 |
-
"initializer_factor": 1.0,
|
| 63 |
-
"initializer_range": 0.02,
|
| 64 |
-
"intermediate_size": 4096,
|
| 65 |
-
"is_decoder": false,
|
| 66 |
-
"is_encoder_decoder": false,
|
| 67 |
-
"label2id": {
|
| 68 |
-
"LABEL_0": 0,
|
| 69 |
-
"LABEL_1": 1
|
| 70 |
-
},
|
| 71 |
-
"layer_norm_eps": 1e-05,
|
| 72 |
-
"length_penalty": 1.0,
|
| 73 |
-
"max_length": 20,
|
| 74 |
-
"min_length": 0,
|
| 75 |
-
"model_type": "clip_vision_model",
|
| 76 |
-
"no_repeat_ngram_size": 0,
|
| 77 |
-
"num_attention_heads": 16,
|
| 78 |
-
"num_beam_groups": 1,
|
| 79 |
-
"num_beams": 1,
|
| 80 |
-
"num_channels": 3,
|
| 81 |
-
"num_hidden_layers": 24,
|
| 82 |
-
"num_return_sequences": 1,
|
| 83 |
-
"output_attentions": false,
|
| 84 |
-
"output_hidden_states": false,
|
| 85 |
-
"output_scores": false,
|
| 86 |
-
"pad_token_id": null,
|
| 87 |
-
"patch_size": 14,
|
| 88 |
-
"prefix": null,
|
| 89 |
-
"problem_type": null,
|
| 90 |
-
"projection_dim": 512,
|
| 91 |
-
"pruned_heads": {},
|
| 92 |
-
"remove_invalid_values": false,
|
| 93 |
-
"repetition_penalty": 1.0,
|
| 94 |
-
"return_dict": true,
|
| 95 |
-
"return_dict_in_generate": false,
|
| 96 |
-
"sep_token_id": null,
|
| 97 |
-
"suppress_tokens": null,
|
| 98 |
-
"task_specific_params": null,
|
| 99 |
-
"temperature": 1.0,
|
| 100 |
-
"tf_legacy_loss": false,
|
| 101 |
-
"tie_encoder_decoder": false,
|
| 102 |
-
"tie_word_embeddings": true,
|
| 103 |
-
"tokenizer_class": null,
|
| 104 |
-
"top_k": 50,
|
| 105 |
-
"top_p": 1.0,
|
| 106 |
-
"torch_dtype": null,
|
| 107 |
-
"torchscript": false,
|
| 108 |
-
"transformers_version": "4.28.1",
|
| 109 |
-
"typical_p": 1.0,
|
| 110 |
-
"use_bfloat16": false
|
| 111 |
-
}
|
| 112 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-llama2-chat-13B.json
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 8,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "meta-llama/Llama-2-13b-chat-hf",
|
| 10 |
-
"architectures": [
|
| 11 |
-
"LlamaForCausalLM"
|
| 12 |
-
],
|
| 13 |
-
"bos_token_id": 1,
|
| 14 |
-
"eos_token_id": 2,
|
| 15 |
-
"hidden_act": "silu",
|
| 16 |
-
"hidden_size": 5120,
|
| 17 |
-
"initializer_range": 0.02,
|
| 18 |
-
"intermediate_size": 13824,
|
| 19 |
-
"max_position_embeddings": 4096,
|
| 20 |
-
"model_type": "llama",
|
| 21 |
-
"num_attention_heads": 40,
|
| 22 |
-
"num_hidden_layers": 40,
|
| 23 |
-
"num_key_value_heads": 40,
|
| 24 |
-
"pad_token_id": 0,
|
| 25 |
-
"pretraining_tp": 1,
|
| 26 |
-
"rms_norm_eps": 1e-05,
|
| 27 |
-
"rope_scaling": null,
|
| 28 |
-
"tie_word_embeddings": false,
|
| 29 |
-
"torch_dtype": "float16",
|
| 30 |
-
"transformers_version": "4.30.1",
|
| 31 |
-
"use_cache": true,
|
| 32 |
-
"vocab_size": 32000
|
| 33 |
-
},
|
| 34 |
-
"torch_dtype": "float32",
|
| 35 |
-
"transformers_version": null,
|
| 36 |
-
"use_media_placement_augmentation": true,
|
| 37 |
-
"vision_config": {
|
| 38 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 39 |
-
"add_cross_attention": false,
|
| 40 |
-
"architectures": null,
|
| 41 |
-
"attention_dropout": 0.0,
|
| 42 |
-
"bad_words_ids": null,
|
| 43 |
-
"begin_suppress_tokens": null,
|
| 44 |
-
"bos_token_id": null,
|
| 45 |
-
"chunk_size_feed_forward": 0,
|
| 46 |
-
"cross_attention_hidden_size": null,
|
| 47 |
-
"decoder_start_token_id": null,
|
| 48 |
-
"diversity_penalty": 0.0,
|
| 49 |
-
"do_sample": false,
|
| 50 |
-
"early_stopping": false,
|
| 51 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 52 |
-
"eos_token_id": null,
|
| 53 |
-
"exponential_decay_length_penalty": null,
|
| 54 |
-
"finetuning_task": null,
|
| 55 |
-
"forced_bos_token_id": null,
|
| 56 |
-
"forced_eos_token_id": null,
|
| 57 |
-
"hidden_act": "quick_gelu",
|
| 58 |
-
"hidden_size": 1024,
|
| 59 |
-
"id2label": {
|
| 60 |
-
"0": "LABEL_0",
|
| 61 |
-
"1": "LABEL_1"
|
| 62 |
-
},
|
| 63 |
-
"image_size": 224,
|
| 64 |
-
"initializer_factor": 1.0,
|
| 65 |
-
"initializer_range": 0.02,
|
| 66 |
-
"intermediate_size": 4096,
|
| 67 |
-
"is_decoder": false,
|
| 68 |
-
"is_encoder_decoder": false,
|
| 69 |
-
"label2id": {
|
| 70 |
-
"LABEL_0": 0,
|
| 71 |
-
"LABEL_1": 1
|
| 72 |
-
},
|
| 73 |
-
"layer_norm_eps": 1e-05,
|
| 74 |
-
"length_penalty": 1.0,
|
| 75 |
-
"max_length": 20,
|
| 76 |
-
"min_length": 0,
|
| 77 |
-
"model_type": "clip_vision_model",
|
| 78 |
-
"no_repeat_ngram_size": 0,
|
| 79 |
-
"num_attention_heads": 16,
|
| 80 |
-
"num_beam_groups": 1,
|
| 81 |
-
"num_beams": 1,
|
| 82 |
-
"num_channels": 3,
|
| 83 |
-
"num_hidden_layers": 24,
|
| 84 |
-
"num_return_sequences": 1,
|
| 85 |
-
"output_attentions": false,
|
| 86 |
-
"output_hidden_states": false,
|
| 87 |
-
"output_scores": false,
|
| 88 |
-
"pad_token_id": null,
|
| 89 |
-
"patch_size": 14,
|
| 90 |
-
"prefix": null,
|
| 91 |
-
"problem_type": null,
|
| 92 |
-
"projection_dim": 512,
|
| 93 |
-
"pruned_heads": {},
|
| 94 |
-
"remove_invalid_values": false,
|
| 95 |
-
"repetition_penalty": 1.0,
|
| 96 |
-
"return_dict": true,
|
| 97 |
-
"return_dict_in_generate": false,
|
| 98 |
-
"sep_token_id": null,
|
| 99 |
-
"suppress_tokens": null,
|
| 100 |
-
"task_specific_params": null,
|
| 101 |
-
"temperature": 1.0,
|
| 102 |
-
"tf_legacy_loss": false,
|
| 103 |
-
"tie_encoder_decoder": false,
|
| 104 |
-
"tie_word_embeddings": true,
|
| 105 |
-
"tokenizer_class": null,
|
| 106 |
-
"top_k": 50,
|
| 107 |
-
"top_p": 1.0,
|
| 108 |
-
"torch_dtype": null,
|
| 109 |
-
"torchscript": false,
|
| 110 |
-
"transformers_version": "4.30.1",
|
| 111 |
-
"typical_p": 1.0,
|
| 112 |
-
"use_bfloat16": false
|
| 113 |
-
}
|
| 114 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-llama2-chat-7B.json
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 4,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
|
| 10 |
-
"architectures": [
|
| 11 |
-
"LlamaForCausalLM"
|
| 12 |
-
],
|
| 13 |
-
"bos_token_id": 1,
|
| 14 |
-
"eos_token_id": 2,
|
| 15 |
-
"hidden_act": "silu",
|
| 16 |
-
"hidden_size": 4096,
|
| 17 |
-
"initializer_range": 0.02,
|
| 18 |
-
"intermediate_size": 11008,
|
| 19 |
-
"max_length": 4096,
|
| 20 |
-
"max_position_embeddings": 2048,
|
| 21 |
-
"model_type": "llama",
|
| 22 |
-
"num_attention_heads": 32,
|
| 23 |
-
"num_hidden_layers": 32,
|
| 24 |
-
"num_key_value_heads": 32,
|
| 25 |
-
"pad_token_id": 0,
|
| 26 |
-
"pretraining_tp": 1,
|
| 27 |
-
"rms_norm_eps": 1e-05,
|
| 28 |
-
"rope_scaling": null,
|
| 29 |
-
"tie_word_embeddings": false,
|
| 30 |
-
"torch_dtype": "float16",
|
| 31 |
-
"transformers_version": "4.32.0.dev0",
|
| 32 |
-
"use_cache": true,
|
| 33 |
-
"vocab_size": 32000
|
| 34 |
-
},
|
| 35 |
-
"torch_dtype": "float32",
|
| 36 |
-
"transformers_version": null,
|
| 37 |
-
"use_media_placement_augmentation": true,
|
| 38 |
-
"vision_config": {
|
| 39 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 40 |
-
"add_cross_attention": false,
|
| 41 |
-
"architectures": null,
|
| 42 |
-
"attention_dropout": 0.0,
|
| 43 |
-
"bad_words_ids": null,
|
| 44 |
-
"begin_suppress_tokens": null,
|
| 45 |
-
"bos_token_id": null,
|
| 46 |
-
"chunk_size_feed_forward": 0,
|
| 47 |
-
"cross_attention_hidden_size": null,
|
| 48 |
-
"decoder_start_token_id": null,
|
| 49 |
-
"diversity_penalty": 0.0,
|
| 50 |
-
"do_sample": false,
|
| 51 |
-
"early_stopping": false,
|
| 52 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 53 |
-
"eos_token_id": null,
|
| 54 |
-
"exponential_decay_length_penalty": null,
|
| 55 |
-
"finetuning_task": null,
|
| 56 |
-
"forced_bos_token_id": null,
|
| 57 |
-
"forced_eos_token_id": null,
|
| 58 |
-
"hidden_act": "quick_gelu",
|
| 59 |
-
"hidden_size": 1024,
|
| 60 |
-
"id2label": {
|
| 61 |
-
"0": "LABEL_0",
|
| 62 |
-
"1": "LABEL_1"
|
| 63 |
-
},
|
| 64 |
-
"image_size": 224,
|
| 65 |
-
"initializer_factor": 1.0,
|
| 66 |
-
"initializer_range": 0.02,
|
| 67 |
-
"intermediate_size": 4096,
|
| 68 |
-
"is_decoder": false,
|
| 69 |
-
"is_encoder_decoder": false,
|
| 70 |
-
"label2id": {
|
| 71 |
-
"LABEL_0": 0,
|
| 72 |
-
"LABEL_1": 1
|
| 73 |
-
},
|
| 74 |
-
"layer_norm_eps": 1e-05,
|
| 75 |
-
"length_penalty": 1.0,
|
| 76 |
-
"max_length": 20,
|
| 77 |
-
"min_length": 0,
|
| 78 |
-
"model_type": "clip_vision_model",
|
| 79 |
-
"no_repeat_ngram_size": 0,
|
| 80 |
-
"num_attention_heads": 16,
|
| 81 |
-
"num_beam_groups": 1,
|
| 82 |
-
"num_beams": 1,
|
| 83 |
-
"num_channels": 3,
|
| 84 |
-
"num_hidden_layers": 24,
|
| 85 |
-
"num_return_sequences": 1,
|
| 86 |
-
"output_attentions": false,
|
| 87 |
-
"output_hidden_states": false,
|
| 88 |
-
"output_scores": false,
|
| 89 |
-
"pad_token_id": null,
|
| 90 |
-
"patch_size": 14,
|
| 91 |
-
"prefix": null,
|
| 92 |
-
"problem_type": null,
|
| 93 |
-
"projection_dim": 512,
|
| 94 |
-
"pruned_heads": {},
|
| 95 |
-
"remove_invalid_values": false,
|
| 96 |
-
"repetition_penalty": 1.0,
|
| 97 |
-
"return_dict": true,
|
| 98 |
-
"return_dict_in_generate": false,
|
| 99 |
-
"sep_token_id": null,
|
| 100 |
-
"suppress_tokens": null,
|
| 101 |
-
"task_specific_params": null,
|
| 102 |
-
"temperature": 1.0,
|
| 103 |
-
"tf_legacy_loss": false,
|
| 104 |
-
"tie_encoder_decoder": false,
|
| 105 |
-
"tie_word_embeddings": true,
|
| 106 |
-
"tokenizer_class": null,
|
| 107 |
-
"top_k": 50,
|
| 108 |
-
"top_p": 1.0,
|
| 109 |
-
"torch_dtype": null,
|
| 110 |
-
"torchscript": false,
|
| 111 |
-
"transformers_version": "4.30.1",
|
| 112 |
-
"typical_p": 1.0,
|
| 113 |
-
"use_bfloat16": false
|
| 114 |
-
}
|
| 115 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-mpt-1B-redpajama.json
DELETED
|
@@ -1,131 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 1,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "",
|
| 10 |
-
"alibi": true,
|
| 11 |
-
"alibi_bias_max": 8,
|
| 12 |
-
"architectures": [
|
| 13 |
-
"MosaicGPT"
|
| 14 |
-
],
|
| 15 |
-
"attn_clip_qkv": null,
|
| 16 |
-
"attn_impl": "torch",
|
| 17 |
-
"attn_pdrop": 0,
|
| 18 |
-
"attn_qk_ln": true,
|
| 19 |
-
"attn_uses_sequence_id": false,
|
| 20 |
-
"d_model": 2048,
|
| 21 |
-
"hidden_size": 2048,
|
| 22 |
-
"emb_init_std": null,
|
| 23 |
-
"emb_init_uniform_lim": null,
|
| 24 |
-
"emb_pdrop": 0,
|
| 25 |
-
"embedding_fraction": 1.0,
|
| 26 |
-
"fan_mode": "fan_in",
|
| 27 |
-
"init_device": "cpu",
|
| 28 |
-
"init_div_is_residual": true,
|
| 29 |
-
"init_gain": 0,
|
| 30 |
-
"init_nonlinearity": "relu",
|
| 31 |
-
"init_std": 0.02,
|
| 32 |
-
"logit_scale": null,
|
| 33 |
-
"low_precision_layernorm": true,
|
| 34 |
-
"max_seq_len": 2048,
|
| 35 |
-
"mlp_ratio": 4,
|
| 36 |
-
"model_type": "mosaic_gpt",
|
| 37 |
-
"n_heads": 16,
|
| 38 |
-
"n_layers": 24,
|
| 39 |
-
"no_bias": true,
|
| 40 |
-
"param_init_fn": "kaiming_normal_",
|
| 41 |
-
"prefix_lm": false,
|
| 42 |
-
"resid_pdrop": 0,
|
| 43 |
-
"softmax_scale": null,
|
| 44 |
-
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
| 45 |
-
"torch_dtype": "float32",
|
| 46 |
-
"transformers_version": "4.27.4",
|
| 47 |
-
"use_cache": false,
|
| 48 |
-
"verbose": 0,
|
| 49 |
-
"vocab_size": 50432
|
| 50 |
-
},
|
| 51 |
-
"torch_dtype": "float32",
|
| 52 |
-
"transformers_version": null,
|
| 53 |
-
"use_media_placement_augmentation": true,
|
| 54 |
-
"vision_config": {
|
| 55 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 56 |
-
"add_cross_attention": false,
|
| 57 |
-
"architectures": null,
|
| 58 |
-
"attention_dropout": 0.0,
|
| 59 |
-
"bad_words_ids": null,
|
| 60 |
-
"begin_suppress_tokens": null,
|
| 61 |
-
"bos_token_id": null,
|
| 62 |
-
"chunk_size_feed_forward": 0,
|
| 63 |
-
"cross_attention_hidden_size": null,
|
| 64 |
-
"decoder_start_token_id": null,
|
| 65 |
-
"diversity_penalty": 0.0,
|
| 66 |
-
"do_sample": false,
|
| 67 |
-
"early_stopping": false,
|
| 68 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 69 |
-
"eos_token_id": null,
|
| 70 |
-
"exponential_decay_length_penalty": null,
|
| 71 |
-
"finetuning_task": null,
|
| 72 |
-
"forced_bos_token_id": null,
|
| 73 |
-
"forced_eos_token_id": null,
|
| 74 |
-
"hidden_act": "quick_gelu",
|
| 75 |
-
"hidden_size": 1024,
|
| 76 |
-
"id2label": {
|
| 77 |
-
"0": "LABEL_0",
|
| 78 |
-
"1": "LABEL_1"
|
| 79 |
-
},
|
| 80 |
-
"image_size": 224,
|
| 81 |
-
"initializer_factor": 1.0,
|
| 82 |
-
"initializer_range": 0.02,
|
| 83 |
-
"intermediate_size": 4096,
|
| 84 |
-
"is_decoder": false,
|
| 85 |
-
"is_encoder_decoder": false,
|
| 86 |
-
"label2id": {
|
| 87 |
-
"LABEL_0": 0,
|
| 88 |
-
"LABEL_1": 1
|
| 89 |
-
},
|
| 90 |
-
"layer_norm_eps": 1e-05,
|
| 91 |
-
"length_penalty": 1.0,
|
| 92 |
-
"max_length": 20,
|
| 93 |
-
"min_length": 0,
|
| 94 |
-
"model_type": "clip_vision_model",
|
| 95 |
-
"no_repeat_ngram_size": 0,
|
| 96 |
-
"num_attention_heads": 16,
|
| 97 |
-
"num_beam_groups": 1,
|
| 98 |
-
"num_beams": 1,
|
| 99 |
-
"num_channels": 3,
|
| 100 |
-
"num_hidden_layers": 24,
|
| 101 |
-
"num_return_sequences": 1,
|
| 102 |
-
"output_attentions": false,
|
| 103 |
-
"output_hidden_states": false,
|
| 104 |
-
"output_scores": false,
|
| 105 |
-
"pad_token_id": null,
|
| 106 |
-
"patch_size": 14,
|
| 107 |
-
"prefix": null,
|
| 108 |
-
"problem_type": null,
|
| 109 |
-
"projection_dim": 512,
|
| 110 |
-
"pruned_heads": {},
|
| 111 |
-
"remove_invalid_values": false,
|
| 112 |
-
"repetition_penalty": 1.0,
|
| 113 |
-
"return_dict": true,
|
| 114 |
-
"return_dict_in_generate": false,
|
| 115 |
-
"sep_token_id": null,
|
| 116 |
-
"suppress_tokens": null,
|
| 117 |
-
"task_specific_params": null,
|
| 118 |
-
"temperature": 1.0,
|
| 119 |
-
"tf_legacy_loss": false,
|
| 120 |
-
"tie_encoder_decoder": false,
|
| 121 |
-
"tie_word_embeddings": true,
|
| 122 |
-
"tokenizer_class": null,
|
| 123 |
-
"top_k": 50,
|
| 124 |
-
"top_p": 1.0,
|
| 125 |
-
"torch_dtype": null,
|
| 126 |
-
"torchscript": false,
|
| 127 |
-
"transformers_version": "4.30.1",
|
| 128 |
-
"typical_p": 1.0,
|
| 129 |
-
"use_bfloat16": false
|
| 130 |
-
}
|
| 131 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-mpt-30B-bf16.json
DELETED
|
@@ -1,195 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 7,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "",
|
| 10 |
-
"add_cross_attention": false,
|
| 11 |
-
"architectures": [
|
| 12 |
-
"MPTForCausalLM"
|
| 13 |
-
],
|
| 14 |
-
"attn_config": {
|
| 15 |
-
"alibi": true,
|
| 16 |
-
"alibi_bias_max": 8,
|
| 17 |
-
"attn_impl": "torch",
|
| 18 |
-
"attn_pdrop": 0,
|
| 19 |
-
"attn_type": "multihead_attention",
|
| 20 |
-
"attn_uses_sequence_id": false,
|
| 21 |
-
"clip_qkv": null,
|
| 22 |
-
"prefix_lm": false,
|
| 23 |
-
"qk_ln": false,
|
| 24 |
-
"softmax_scale": null
|
| 25 |
-
},
|
| 26 |
-
"bad_words_ids": null,
|
| 27 |
-
"begin_suppress_tokens": null,
|
| 28 |
-
"bos_token_id": null,
|
| 29 |
-
"chunk_size_feed_forward": 0,
|
| 30 |
-
"cross_attention_hidden_size": null,
|
| 31 |
-
"d_model": 7168,
|
| 32 |
-
"decoder_start_token_id": null,
|
| 33 |
-
"diversity_penalty": 0.0,
|
| 34 |
-
"do_sample": false,
|
| 35 |
-
"early_stopping": false,
|
| 36 |
-
"emb_pdrop": 0,
|
| 37 |
-
"embedding_fraction": 1.0,
|
| 38 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 39 |
-
"eos_token_id": null,
|
| 40 |
-
"expansion_ratio": 4,
|
| 41 |
-
"exponential_decay_length_penalty": null,
|
| 42 |
-
"finetuning_task": null,
|
| 43 |
-
"forced_bos_token_id": null,
|
| 44 |
-
"forced_eos_token_id": null,
|
| 45 |
-
"hidden_size": 7168,
|
| 46 |
-
"id2label": {
|
| 47 |
-
"0": "LABEL_0",
|
| 48 |
-
"1": "LABEL_1"
|
| 49 |
-
},
|
| 50 |
-
"init_config": {
|
| 51 |
-
"emb_init_std": null,
|
| 52 |
-
"emb_init_uniform_lim": null,
|
| 53 |
-
"fan_mode": "fan_in",
|
| 54 |
-
"init_div_is_residual": true,
|
| 55 |
-
"init_gain": 0.0,
|
| 56 |
-
"init_nonlinearity": "relu",
|
| 57 |
-
"init_std": null,
|
| 58 |
-
"name": "kaiming_normal_",
|
| 59 |
-
"verbose": 0
|
| 60 |
-
},
|
| 61 |
-
"init_device": "cpu",
|
| 62 |
-
"is_decoder": false,
|
| 63 |
-
"is_encoder_decoder": false,
|
| 64 |
-
"label2id": {
|
| 65 |
-
"LABEL_0": 0,
|
| 66 |
-
"LABEL_1": 1
|
| 67 |
-
},
|
| 68 |
-
"learned_pos_emb": true,
|
| 69 |
-
"length_penalty": 1.0,
|
| 70 |
-
"logit_scale": null,
|
| 71 |
-
"max_length": 20,
|
| 72 |
-
"max_seq_len": 8192,
|
| 73 |
-
"min_length": 0,
|
| 74 |
-
"model_type": "mpt",
|
| 75 |
-
"n_heads": 64,
|
| 76 |
-
"n_layers": 48,
|
| 77 |
-
"no_bias": true,
|
| 78 |
-
"no_repeat_ngram_size": 0,
|
| 79 |
-
"norm_type": "low_precision_layernorm",
|
| 80 |
-
"num_beam_groups": 1,
|
| 81 |
-
"num_beams": 1,
|
| 82 |
-
"num_return_sequences": 1,
|
| 83 |
-
"output_attentions": false,
|
| 84 |
-
"output_hidden_states": false,
|
| 85 |
-
"output_scores": false,
|
| 86 |
-
"pad_token_id": null,
|
| 87 |
-
"prefix": null,
|
| 88 |
-
"problem_type": null,
|
| 89 |
-
"pruned_heads": {},
|
| 90 |
-
"remove_invalid_values": false,
|
| 91 |
-
"repetition_penalty": 1.0,
|
| 92 |
-
"resid_pdrop": 0,
|
| 93 |
-
"return_dict": true,
|
| 94 |
-
"return_dict_in_generate": false,
|
| 95 |
-
"sep_token_id": null,
|
| 96 |
-
"suppress_tokens": null,
|
| 97 |
-
"task_specific_params": null,
|
| 98 |
-
"temperature": 1.0,
|
| 99 |
-
"tf_legacy_loss": false,
|
| 100 |
-
"tie_encoder_decoder": false,
|
| 101 |
-
"tie_word_embeddings": true,
|
| 102 |
-
"tokenizer_class": null,
|
| 103 |
-
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
| 104 |
-
"top_k": 50,
|
| 105 |
-
"top_p": 1.0,
|
| 106 |
-
"torch_dtype": "bfloat16",
|
| 107 |
-
"torchscript": false,
|
| 108 |
-
"transformers_version": "4.30.1",
|
| 109 |
-
"typical_p": 1.0,
|
| 110 |
-
"use_bfloat16": false,
|
| 111 |
-
"use_cache": false,
|
| 112 |
-
"verbose": 0,
|
| 113 |
-
"vocab_size": 50432
|
| 114 |
-
},
|
| 115 |
-
"torch_dtype": "bfloat16",
|
| 116 |
-
"transformers_version": null,
|
| 117 |
-
"use_media_placement_augmentation": true,
|
| 118 |
-
"vision_config": {
|
| 119 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 120 |
-
"add_cross_attention": false,
|
| 121 |
-
"architectures": null,
|
| 122 |
-
"attention_dropout": 0.0,
|
| 123 |
-
"bad_words_ids": null,
|
| 124 |
-
"begin_suppress_tokens": null,
|
| 125 |
-
"bos_token_id": null,
|
| 126 |
-
"chunk_size_feed_forward": 0,
|
| 127 |
-
"cross_attention_hidden_size": null,
|
| 128 |
-
"decoder_start_token_id": null,
|
| 129 |
-
"diversity_penalty": 0.0,
|
| 130 |
-
"do_sample": false,
|
| 131 |
-
"early_stopping": false,
|
| 132 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 133 |
-
"eos_token_id": null,
|
| 134 |
-
"exponential_decay_length_penalty": null,
|
| 135 |
-
"finetuning_task": null,
|
| 136 |
-
"forced_bos_token_id": null,
|
| 137 |
-
"forced_eos_token_id": null,
|
| 138 |
-
"hidden_act": "quick_gelu",
|
| 139 |
-
"hidden_size": 1024,
|
| 140 |
-
"id2label": {
|
| 141 |
-
"0": "LABEL_0",
|
| 142 |
-
"1": "LABEL_1"
|
| 143 |
-
},
|
| 144 |
-
"image_size": 224,
|
| 145 |
-
"initializer_factor": 1.0,
|
| 146 |
-
"initializer_range": 0.02,
|
| 147 |
-
"intermediate_size": 4096,
|
| 148 |
-
"is_decoder": false,
|
| 149 |
-
"is_encoder_decoder": false,
|
| 150 |
-
"label2id": {
|
| 151 |
-
"LABEL_0": 0,
|
| 152 |
-
"LABEL_1": 1
|
| 153 |
-
},
|
| 154 |
-
"layer_norm_eps": 1e-05,
|
| 155 |
-
"length_penalty": 1.0,
|
| 156 |
-
"max_length": 20,
|
| 157 |
-
"min_length": 0,
|
| 158 |
-
"model_type": "clip_vision_model",
|
| 159 |
-
"no_repeat_ngram_size": 0,
|
| 160 |
-
"num_attention_heads": 16,
|
| 161 |
-
"num_beam_groups": 1,
|
| 162 |
-
"num_beams": 1,
|
| 163 |
-
"num_channels": 3,
|
| 164 |
-
"num_hidden_layers": 24,
|
| 165 |
-
"num_return_sequences": 1,
|
| 166 |
-
"output_attentions": false,
|
| 167 |
-
"output_hidden_states": false,
|
| 168 |
-
"output_scores": false,
|
| 169 |
-
"pad_token_id": null,
|
| 170 |
-
"patch_size": 14,
|
| 171 |
-
"prefix": null,
|
| 172 |
-
"problem_type": null,
|
| 173 |
-
"projection_dim": 512,
|
| 174 |
-
"pruned_heads": {},
|
| 175 |
-
"remove_invalid_values": false,
|
| 176 |
-
"repetition_penalty": 1.0,
|
| 177 |
-
"return_dict": true,
|
| 178 |
-
"return_dict_in_generate": false,
|
| 179 |
-
"sep_token_id": null,
|
| 180 |
-
"suppress_tokens": null,
|
| 181 |
-
"task_specific_params": null,
|
| 182 |
-
"temperature": 1.0,
|
| 183 |
-
"tf_legacy_loss": false,
|
| 184 |
-
"tie_encoder_decoder": false,
|
| 185 |
-
"tie_word_embeddings": true,
|
| 186 |
-
"tokenizer_class": null,
|
| 187 |
-
"top_k": 50,
|
| 188 |
-
"top_p": 1.0,
|
| 189 |
-
"torch_dtype": null,
|
| 190 |
-
"torchscript": false,
|
| 191 |
-
"transformers_version": "4.30.1",
|
| 192 |
-
"typical_p": 1.0,
|
| 193 |
-
"use_bfloat16": false
|
| 194 |
-
}
|
| 195 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-mpt-30B.json
DELETED
|
@@ -1,195 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 7,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "",
|
| 10 |
-
"add_cross_attention": false,
|
| 11 |
-
"architectures": [
|
| 12 |
-
"MPTForCausalLM"
|
| 13 |
-
],
|
| 14 |
-
"attn_config": {
|
| 15 |
-
"alibi": true,
|
| 16 |
-
"alibi_bias_max": 8,
|
| 17 |
-
"attn_impl": "torch",
|
| 18 |
-
"attn_pdrop": 0,
|
| 19 |
-
"attn_type": "multihead_attention",
|
| 20 |
-
"attn_uses_sequence_id": false,
|
| 21 |
-
"clip_qkv": null,
|
| 22 |
-
"prefix_lm": false,
|
| 23 |
-
"qk_ln": false,
|
| 24 |
-
"softmax_scale": null
|
| 25 |
-
},
|
| 26 |
-
"bad_words_ids": null,
|
| 27 |
-
"begin_suppress_tokens": null,
|
| 28 |
-
"bos_token_id": null,
|
| 29 |
-
"chunk_size_feed_forward": 0,
|
| 30 |
-
"cross_attention_hidden_size": null,
|
| 31 |
-
"d_model": 7168,
|
| 32 |
-
"decoder_start_token_id": null,
|
| 33 |
-
"diversity_penalty": 0.0,
|
| 34 |
-
"do_sample": false,
|
| 35 |
-
"early_stopping": false,
|
| 36 |
-
"emb_pdrop": 0,
|
| 37 |
-
"embedding_fraction": 1.0,
|
| 38 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 39 |
-
"eos_token_id": null,
|
| 40 |
-
"expansion_ratio": 4,
|
| 41 |
-
"exponential_decay_length_penalty": null,
|
| 42 |
-
"finetuning_task": null,
|
| 43 |
-
"forced_bos_token_id": null,
|
| 44 |
-
"forced_eos_token_id": null,
|
| 45 |
-
"hidden_size": 7168,
|
| 46 |
-
"id2label": {
|
| 47 |
-
"0": "LABEL_0",
|
| 48 |
-
"1": "LABEL_1"
|
| 49 |
-
},
|
| 50 |
-
"init_config": {
|
| 51 |
-
"emb_init_std": null,
|
| 52 |
-
"emb_init_uniform_lim": null,
|
| 53 |
-
"fan_mode": "fan_in",
|
| 54 |
-
"init_div_is_residual": true,
|
| 55 |
-
"init_gain": 0.0,
|
| 56 |
-
"init_nonlinearity": "relu",
|
| 57 |
-
"init_std": null,
|
| 58 |
-
"name": "kaiming_normal_",
|
| 59 |
-
"verbose": 0
|
| 60 |
-
},
|
| 61 |
-
"init_device": "cpu",
|
| 62 |
-
"is_decoder": false,
|
| 63 |
-
"is_encoder_decoder": false,
|
| 64 |
-
"label2id": {
|
| 65 |
-
"LABEL_0": 0,
|
| 66 |
-
"LABEL_1": 1
|
| 67 |
-
},
|
| 68 |
-
"learned_pos_emb": true,
|
| 69 |
-
"length_penalty": 1.0,
|
| 70 |
-
"logit_scale": null,
|
| 71 |
-
"max_length": 20,
|
| 72 |
-
"max_seq_len": 8192,
|
| 73 |
-
"min_length": 0,
|
| 74 |
-
"model_type": "mpt",
|
| 75 |
-
"n_heads": 64,
|
| 76 |
-
"n_layers": 48,
|
| 77 |
-
"no_bias": true,
|
| 78 |
-
"no_repeat_ngram_size": 0,
|
| 79 |
-
"norm_type": "low_precision_layernorm",
|
| 80 |
-
"num_beam_groups": 1,
|
| 81 |
-
"num_beams": 1,
|
| 82 |
-
"num_return_sequences": 1,
|
| 83 |
-
"output_attentions": false,
|
| 84 |
-
"output_hidden_states": false,
|
| 85 |
-
"output_scores": false,
|
| 86 |
-
"pad_token_id": null,
|
| 87 |
-
"prefix": null,
|
| 88 |
-
"problem_type": null,
|
| 89 |
-
"pruned_heads": {},
|
| 90 |
-
"remove_invalid_values": false,
|
| 91 |
-
"repetition_penalty": 1.0,
|
| 92 |
-
"resid_pdrop": 0,
|
| 93 |
-
"return_dict": true,
|
| 94 |
-
"return_dict_in_generate": false,
|
| 95 |
-
"sep_token_id": null,
|
| 96 |
-
"suppress_tokens": null,
|
| 97 |
-
"task_specific_params": null,
|
| 98 |
-
"temperature": 1.0,
|
| 99 |
-
"tf_legacy_loss": false,
|
| 100 |
-
"tie_encoder_decoder": false,
|
| 101 |
-
"tie_word_embeddings": true,
|
| 102 |
-
"tokenizer_class": null,
|
| 103 |
-
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
| 104 |
-
"top_k": 50,
|
| 105 |
-
"top_p": 1.0,
|
| 106 |
-
"torch_dtype": "bfloat16",
|
| 107 |
-
"torchscript": false,
|
| 108 |
-
"transformers_version": "4.30.1",
|
| 109 |
-
"typical_p": 1.0,
|
| 110 |
-
"use_bfloat16": false,
|
| 111 |
-
"use_cache": false,
|
| 112 |
-
"verbose": 0,
|
| 113 |
-
"vocab_size": 50432
|
| 114 |
-
},
|
| 115 |
-
"torch_dtype": "float32",
|
| 116 |
-
"transformers_version": null,
|
| 117 |
-
"use_media_placement_augmentation": true,
|
| 118 |
-
"vision_config": {
|
| 119 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 120 |
-
"add_cross_attention": false,
|
| 121 |
-
"architectures": null,
|
| 122 |
-
"attention_dropout": 0.0,
|
| 123 |
-
"bad_words_ids": null,
|
| 124 |
-
"begin_suppress_tokens": null,
|
| 125 |
-
"bos_token_id": null,
|
| 126 |
-
"chunk_size_feed_forward": 0,
|
| 127 |
-
"cross_attention_hidden_size": null,
|
| 128 |
-
"decoder_start_token_id": null,
|
| 129 |
-
"diversity_penalty": 0.0,
|
| 130 |
-
"do_sample": false,
|
| 131 |
-
"early_stopping": false,
|
| 132 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 133 |
-
"eos_token_id": null,
|
| 134 |
-
"exponential_decay_length_penalty": null,
|
| 135 |
-
"finetuning_task": null,
|
| 136 |
-
"forced_bos_token_id": null,
|
| 137 |
-
"forced_eos_token_id": null,
|
| 138 |
-
"hidden_act": "quick_gelu",
|
| 139 |
-
"hidden_size": 1024,
|
| 140 |
-
"id2label": {
|
| 141 |
-
"0": "LABEL_0",
|
| 142 |
-
"1": "LABEL_1"
|
| 143 |
-
},
|
| 144 |
-
"image_size": 224,
|
| 145 |
-
"initializer_factor": 1.0,
|
| 146 |
-
"initializer_range": 0.02,
|
| 147 |
-
"intermediate_size": 4096,
|
| 148 |
-
"is_decoder": false,
|
| 149 |
-
"is_encoder_decoder": false,
|
| 150 |
-
"label2id": {
|
| 151 |
-
"LABEL_0": 0,
|
| 152 |
-
"LABEL_1": 1
|
| 153 |
-
},
|
| 154 |
-
"layer_norm_eps": 1e-05,
|
| 155 |
-
"length_penalty": 1.0,
|
| 156 |
-
"max_length": 20,
|
| 157 |
-
"min_length": 0,
|
| 158 |
-
"model_type": "clip_vision_model",
|
| 159 |
-
"no_repeat_ngram_size": 0,
|
| 160 |
-
"num_attention_heads": 16,
|
| 161 |
-
"num_beam_groups": 1,
|
| 162 |
-
"num_beams": 1,
|
| 163 |
-
"num_channels": 3,
|
| 164 |
-
"num_hidden_layers": 24,
|
| 165 |
-
"num_return_sequences": 1,
|
| 166 |
-
"output_attentions": false,
|
| 167 |
-
"output_hidden_states": false,
|
| 168 |
-
"output_scores": false,
|
| 169 |
-
"pad_token_id": null,
|
| 170 |
-
"patch_size": 14,
|
| 171 |
-
"prefix": null,
|
| 172 |
-
"problem_type": null,
|
| 173 |
-
"projection_dim": 512,
|
| 174 |
-
"pruned_heads": {},
|
| 175 |
-
"remove_invalid_values": false,
|
| 176 |
-
"repetition_penalty": 1.0,
|
| 177 |
-
"return_dict": true,
|
| 178 |
-
"return_dict_in_generate": false,
|
| 179 |
-
"sep_token_id": null,
|
| 180 |
-
"suppress_tokens": null,
|
| 181 |
-
"task_specific_params": null,
|
| 182 |
-
"temperature": 1.0,
|
| 183 |
-
"tf_legacy_loss": false,
|
| 184 |
-
"tie_encoder_decoder": false,
|
| 185 |
-
"tie_word_embeddings": true,
|
| 186 |
-
"tokenizer_class": null,
|
| 187 |
-
"top_k": 50,
|
| 188 |
-
"top_p": 1.0,
|
| 189 |
-
"torch_dtype": null,
|
| 190 |
-
"torchscript": false,
|
| 191 |
-
"transformers_version": "4.30.1",
|
| 192 |
-
"typical_p": 1.0,
|
| 193 |
-
"use_bfloat16": false
|
| 194 |
-
}
|
| 195 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-mpt-7B.json
DELETED
|
@@ -1,195 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 4,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "",
|
| 10 |
-
"add_cross_attention": false,
|
| 11 |
-
"architectures": [
|
| 12 |
-
"MPTForCausalLM"
|
| 13 |
-
],
|
| 14 |
-
"attn_config": {
|
| 15 |
-
"alibi": true,
|
| 16 |
-
"alibi_bias_max": 8,
|
| 17 |
-
"attn_impl": "torch",
|
| 18 |
-
"attn_pdrop": 0,
|
| 19 |
-
"attn_type": "multihead_attention",
|
| 20 |
-
"attn_uses_sequence_id": false,
|
| 21 |
-
"clip_qkv": null,
|
| 22 |
-
"prefix_lm": false,
|
| 23 |
-
"qk_ln": false,
|
| 24 |
-
"softmax_scale": null
|
| 25 |
-
},
|
| 26 |
-
"bad_words_ids": null,
|
| 27 |
-
"begin_suppress_tokens": null,
|
| 28 |
-
"bos_token_id": null,
|
| 29 |
-
"chunk_size_feed_forward": 0,
|
| 30 |
-
"cross_attention_hidden_size": null,
|
| 31 |
-
"d_model": 4096,
|
| 32 |
-
"decoder_start_token_id": null,
|
| 33 |
-
"diversity_penalty": 0.0,
|
| 34 |
-
"do_sample": false,
|
| 35 |
-
"early_stopping": false,
|
| 36 |
-
"emb_pdrop": 0,
|
| 37 |
-
"embedding_fraction": 1.0,
|
| 38 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 39 |
-
"eos_token_id": null,
|
| 40 |
-
"expansion_ratio": 4,
|
| 41 |
-
"exponential_decay_length_penalty": null,
|
| 42 |
-
"finetuning_task": null,
|
| 43 |
-
"forced_bos_token_id": null,
|
| 44 |
-
"forced_eos_token_id": null,
|
| 45 |
-
"hidden_size": 4096,
|
| 46 |
-
"id2label": {
|
| 47 |
-
"0": "LABEL_0",
|
| 48 |
-
"1": "LABEL_1"
|
| 49 |
-
},
|
| 50 |
-
"init_config": {
|
| 51 |
-
"emb_init_std": null,
|
| 52 |
-
"emb_init_uniform_lim": null,
|
| 53 |
-
"fan_mode": "fan_in",
|
| 54 |
-
"init_div_is_residual": true,
|
| 55 |
-
"init_gain": 0,
|
| 56 |
-
"init_nonlinearity": "relu",
|
| 57 |
-
"init_std": 0.02,
|
| 58 |
-
"name": "kaiming_normal_",
|
| 59 |
-
"verbose": 0
|
| 60 |
-
},
|
| 61 |
-
"init_device": "cpu",
|
| 62 |
-
"is_decoder": false,
|
| 63 |
-
"is_encoder_decoder": false,
|
| 64 |
-
"label2id": {
|
| 65 |
-
"LABEL_0": 0,
|
| 66 |
-
"LABEL_1": 1
|
| 67 |
-
},
|
| 68 |
-
"learned_pos_emb": true,
|
| 69 |
-
"length_penalty": 1.0,
|
| 70 |
-
"logit_scale": null,
|
| 71 |
-
"max_length": 20,
|
| 72 |
-
"max_seq_len": 2048,
|
| 73 |
-
"min_length": 0,
|
| 74 |
-
"model_type": "mpt",
|
| 75 |
-
"n_heads": 32,
|
| 76 |
-
"n_layers": 32,
|
| 77 |
-
"no_bias": true,
|
| 78 |
-
"no_repeat_ngram_size": 0,
|
| 79 |
-
"norm_type": "low_precision_layernorm",
|
| 80 |
-
"num_beam_groups": 1,
|
| 81 |
-
"num_beams": 1,
|
| 82 |
-
"num_return_sequences": 1,
|
| 83 |
-
"output_attentions": false,
|
| 84 |
-
"output_hidden_states": false,
|
| 85 |
-
"output_scores": false,
|
| 86 |
-
"pad_token_id": null,
|
| 87 |
-
"prefix": null,
|
| 88 |
-
"problem_type": null,
|
| 89 |
-
"pruned_heads": {},
|
| 90 |
-
"remove_invalid_values": false,
|
| 91 |
-
"repetition_penalty": 1.0,
|
| 92 |
-
"resid_pdrop": 0,
|
| 93 |
-
"return_dict": true,
|
| 94 |
-
"return_dict_in_generate": false,
|
| 95 |
-
"sep_token_id": null,
|
| 96 |
-
"suppress_tokens": null,
|
| 97 |
-
"task_specific_params": null,
|
| 98 |
-
"temperature": 1.0,
|
| 99 |
-
"tf_legacy_loss": false,
|
| 100 |
-
"tie_encoder_decoder": false,
|
| 101 |
-
"tie_word_embeddings": true,
|
| 102 |
-
"tokenizer_class": null,
|
| 103 |
-
"tokenizer_name": "EleutherAI/gpt-neox-20b",
|
| 104 |
-
"top_k": 50,
|
| 105 |
-
"top_p": 1.0,
|
| 106 |
-
"torch_dtype": "bfloat16",
|
| 107 |
-
"torchscript": false,
|
| 108 |
-
"transformers_version": "4.30.1",
|
| 109 |
-
"typical_p": 1.0,
|
| 110 |
-
"use_bfloat16": false,
|
| 111 |
-
"use_cache": false,
|
| 112 |
-
"verbose": 0,
|
| 113 |
-
"vocab_size": 50432
|
| 114 |
-
},
|
| 115 |
-
"torch_dtype": "float32",
|
| 116 |
-
"transformers_version": null,
|
| 117 |
-
"use_media_placement_augmentation": true,
|
| 118 |
-
"vision_config": {
|
| 119 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 120 |
-
"add_cross_attention": false,
|
| 121 |
-
"architectures": null,
|
| 122 |
-
"attention_dropout": 0.0,
|
| 123 |
-
"bad_words_ids": null,
|
| 124 |
-
"begin_suppress_tokens": null,
|
| 125 |
-
"bos_token_id": null,
|
| 126 |
-
"chunk_size_feed_forward": 0,
|
| 127 |
-
"cross_attention_hidden_size": null,
|
| 128 |
-
"decoder_start_token_id": null,
|
| 129 |
-
"diversity_penalty": 0.0,
|
| 130 |
-
"do_sample": false,
|
| 131 |
-
"early_stopping": false,
|
| 132 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 133 |
-
"eos_token_id": null,
|
| 134 |
-
"exponential_decay_length_penalty": null,
|
| 135 |
-
"finetuning_task": null,
|
| 136 |
-
"forced_bos_token_id": null,
|
| 137 |
-
"forced_eos_token_id": null,
|
| 138 |
-
"hidden_act": "quick_gelu",
|
| 139 |
-
"hidden_size": 1024,
|
| 140 |
-
"id2label": {
|
| 141 |
-
"0": "LABEL_0",
|
| 142 |
-
"1": "LABEL_1"
|
| 143 |
-
},
|
| 144 |
-
"image_size": 224,
|
| 145 |
-
"initializer_factor": 1.0,
|
| 146 |
-
"initializer_range": 0.02,
|
| 147 |
-
"intermediate_size": 4096,
|
| 148 |
-
"is_decoder": false,
|
| 149 |
-
"is_encoder_decoder": false,
|
| 150 |
-
"label2id": {
|
| 151 |
-
"LABEL_0": 0,
|
| 152 |
-
"LABEL_1": 1
|
| 153 |
-
},
|
| 154 |
-
"layer_norm_eps": 1e-05,
|
| 155 |
-
"length_penalty": 1.0,
|
| 156 |
-
"max_length": 20,
|
| 157 |
-
"min_length": 0,
|
| 158 |
-
"model_type": "clip_vision_model",
|
| 159 |
-
"no_repeat_ngram_size": 0,
|
| 160 |
-
"num_attention_heads": 16,
|
| 161 |
-
"num_beam_groups": 1,
|
| 162 |
-
"num_beams": 1,
|
| 163 |
-
"num_channels": 3,
|
| 164 |
-
"num_hidden_layers": 24,
|
| 165 |
-
"num_return_sequences": 1,
|
| 166 |
-
"output_attentions": false,
|
| 167 |
-
"output_hidden_states": false,
|
| 168 |
-
"output_scores": false,
|
| 169 |
-
"pad_token_id": null,
|
| 170 |
-
"patch_size": 14,
|
| 171 |
-
"prefix": null,
|
| 172 |
-
"problem_type": null,
|
| 173 |
-
"projection_dim": 512,
|
| 174 |
-
"pruned_heads": {},
|
| 175 |
-
"remove_invalid_values": false,
|
| 176 |
-
"repetition_penalty": 1.0,
|
| 177 |
-
"return_dict": true,
|
| 178 |
-
"return_dict_in_generate": false,
|
| 179 |
-
"sep_token_id": null,
|
| 180 |
-
"suppress_tokens": null,
|
| 181 |
-
"task_specific_params": null,
|
| 182 |
-
"temperature": 1.0,
|
| 183 |
-
"tf_legacy_loss": false,
|
| 184 |
-
"tie_encoder_decoder": false,
|
| 185 |
-
"tie_word_embeddings": true,
|
| 186 |
-
"tokenizer_class": null,
|
| 187 |
-
"top_k": 50,
|
| 188 |
-
"top_p": 1.0,
|
| 189 |
-
"torch_dtype": null,
|
| 190 |
-
"torchscript": false,
|
| 191 |
-
"transformers_version": "4.30.1",
|
| 192 |
-
"typical_p": 1.0,
|
| 193 |
-
"use_bfloat16": false
|
| 194 |
-
}
|
| 195 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-vicuna-33B-v1.3.json
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 4,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "/home/luodian/projects/checkpoints/vicuna-33b-v1.3",
|
| 10 |
-
"architectures": [
|
| 11 |
-
"LlamaForCausalLM"
|
| 12 |
-
],
|
| 13 |
-
"bos_token_id": 1,
|
| 14 |
-
"eos_token_id": 2,
|
| 15 |
-
"hidden_act": "silu",
|
| 16 |
-
"hidden_size": 6656,
|
| 17 |
-
"initializer_range": 0.02,
|
| 18 |
-
"intermediate_size": 17920,
|
| 19 |
-
"max_position_embeddings": 2048,
|
| 20 |
-
"model_type": "llama",
|
| 21 |
-
"num_attention_heads": 52,
|
| 22 |
-
"num_hidden_layers": 60,
|
| 23 |
-
"pad_token_id": 0,
|
| 24 |
-
"rms_norm_eps": 1e-06,
|
| 25 |
-
"tie_word_embeddings": false,
|
| 26 |
-
"torch_dtype": "float16",
|
| 27 |
-
"transformers_version": "4.28.1",
|
| 28 |
-
"use_cache": false,
|
| 29 |
-
"vocab_size": 32000
|
| 30 |
-
},
|
| 31 |
-
"torch_dtype": "float32",
|
| 32 |
-
"transformers_version": null,
|
| 33 |
-
"use_media_placement_augmentation": true,
|
| 34 |
-
"vision_config": {
|
| 35 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 36 |
-
"add_cross_attention": false,
|
| 37 |
-
"architectures": null,
|
| 38 |
-
"attention_dropout": 0.0,
|
| 39 |
-
"bad_words_ids": null,
|
| 40 |
-
"begin_suppress_tokens": null,
|
| 41 |
-
"bos_token_id": null,
|
| 42 |
-
"chunk_size_feed_forward": 0,
|
| 43 |
-
"cross_attention_hidden_size": null,
|
| 44 |
-
"decoder_start_token_id": null,
|
| 45 |
-
"diversity_penalty": 0.0,
|
| 46 |
-
"do_sample": false,
|
| 47 |
-
"early_stopping": false,
|
| 48 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 49 |
-
"eos_token_id": null,
|
| 50 |
-
"exponential_decay_length_penalty": null,
|
| 51 |
-
"finetuning_task": null,
|
| 52 |
-
"forced_bos_token_id": null,
|
| 53 |
-
"forced_eos_token_id": null,
|
| 54 |
-
"hidden_act": "quick_gelu",
|
| 55 |
-
"hidden_size": 1024,
|
| 56 |
-
"id2label": {
|
| 57 |
-
"0": "LABEL_0",
|
| 58 |
-
"1": "LABEL_1"
|
| 59 |
-
},
|
| 60 |
-
"image_size": 224,
|
| 61 |
-
"initializer_factor": 1.0,
|
| 62 |
-
"initializer_range": 0.02,
|
| 63 |
-
"intermediate_size": 4096,
|
| 64 |
-
"is_decoder": false,
|
| 65 |
-
"is_encoder_decoder": false,
|
| 66 |
-
"label2id": {
|
| 67 |
-
"LABEL_0": 0,
|
| 68 |
-
"LABEL_1": 1
|
| 69 |
-
},
|
| 70 |
-
"layer_norm_eps": 1e-05,
|
| 71 |
-
"length_penalty": 1.0,
|
| 72 |
-
"max_length": 20,
|
| 73 |
-
"min_length": 0,
|
| 74 |
-
"model_type": "clip_vision_model",
|
| 75 |
-
"no_repeat_ngram_size": 0,
|
| 76 |
-
"num_attention_heads": 16,
|
| 77 |
-
"num_beam_groups": 1,
|
| 78 |
-
"num_beams": 1,
|
| 79 |
-
"num_channels": 3,
|
| 80 |
-
"num_hidden_layers": 24,
|
| 81 |
-
"num_return_sequences": 1,
|
| 82 |
-
"output_attentions": false,
|
| 83 |
-
"output_hidden_states": false,
|
| 84 |
-
"output_scores": false,
|
| 85 |
-
"pad_token_id": null,
|
| 86 |
-
"patch_size": 14,
|
| 87 |
-
"prefix": null,
|
| 88 |
-
"problem_type": null,
|
| 89 |
-
"projection_dim": 512,
|
| 90 |
-
"pruned_heads": {},
|
| 91 |
-
"remove_invalid_values": false,
|
| 92 |
-
"repetition_penalty": 1.0,
|
| 93 |
-
"return_dict": true,
|
| 94 |
-
"return_dict_in_generate": false,
|
| 95 |
-
"sep_token_id": null,
|
| 96 |
-
"suppress_tokens": null,
|
| 97 |
-
"task_specific_params": null,
|
| 98 |
-
"temperature": 1.0,
|
| 99 |
-
"tf_legacy_loss": false,
|
| 100 |
-
"tie_encoder_decoder": false,
|
| 101 |
-
"tie_word_embeddings": true,
|
| 102 |
-
"tokenizer_class": null,
|
| 103 |
-
"top_k": 50,
|
| 104 |
-
"top_p": 1.0,
|
| 105 |
-
"torch_dtype": null,
|
| 106 |
-
"torchscript": false,
|
| 107 |
-
"transformers_version": "4.30.1",
|
| 108 |
-
"typical_p": 1.0,
|
| 109 |
-
"use_bfloat16": false
|
| 110 |
-
}
|
| 111 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/flamingo-vicuna-7B-v1.3.json
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"_commit_hash": null,
|
| 3 |
-
"architectures": [
|
| 4 |
-
"FlamingoForConditionalGeneration"
|
| 5 |
-
],
|
| 6 |
-
"cross_attn_every_n_layers": 4,
|
| 7 |
-
"model_type": "flamingo",
|
| 8 |
-
"text_config": {
|
| 9 |
-
"_name_or_path": "/mnt/petrelfs/share_data/zhangyuanhan/vicuna-7b-v1.3",
|
| 10 |
-
"architectures": [
|
| 11 |
-
"LlamaForCausalLM"
|
| 12 |
-
],
|
| 13 |
-
"bos_token_id": 1,
|
| 14 |
-
"eos_token_id": 2,
|
| 15 |
-
"hidden_act": "silu",
|
| 16 |
-
"hidden_size": 4096,
|
| 17 |
-
"initializer_range": 0.02,
|
| 18 |
-
"intermediate_size": 11008,
|
| 19 |
-
"max_position_embeddings": 2048,
|
| 20 |
-
"model_type": "llama",
|
| 21 |
-
"num_attention_heads": 32,
|
| 22 |
-
"num_hidden_layers": 32,
|
| 23 |
-
"pad_token_id": 0,
|
| 24 |
-
"rms_norm_eps": 1e-06,
|
| 25 |
-
"tie_word_embeddings": false,
|
| 26 |
-
"torch_dtype": "float16",
|
| 27 |
-
"transformers_version": "4.28.1",
|
| 28 |
-
"use_cache": false,
|
| 29 |
-
"vocab_size": 32000
|
| 30 |
-
},
|
| 31 |
-
"torch_dtype": "float32",
|
| 32 |
-
"transformers_version": null,
|
| 33 |
-
"use_media_placement_augmentation": true,
|
| 34 |
-
"vision_config": {
|
| 35 |
-
"_name_or_path": "openai/clip-vit-large-patch14",
|
| 36 |
-
"add_cross_attention": false,
|
| 37 |
-
"architectures": null,
|
| 38 |
-
"attention_dropout": 0.0,
|
| 39 |
-
"bad_words_ids": null,
|
| 40 |
-
"begin_suppress_tokens": null,
|
| 41 |
-
"bos_token_id": null,
|
| 42 |
-
"chunk_size_feed_forward": 0,
|
| 43 |
-
"cross_attention_hidden_size": null,
|
| 44 |
-
"decoder_start_token_id": null,
|
| 45 |
-
"diversity_penalty": 0.0,
|
| 46 |
-
"do_sample": false,
|
| 47 |
-
"early_stopping": false,
|
| 48 |
-
"encoder_no_repeat_ngram_size": 0,
|
| 49 |
-
"eos_token_id": null,
|
| 50 |
-
"exponential_decay_length_penalty": null,
|
| 51 |
-
"finetuning_task": null,
|
| 52 |
-
"forced_bos_token_id": null,
|
| 53 |
-
"forced_eos_token_id": null,
|
| 54 |
-
"hidden_act": "quick_gelu",
|
| 55 |
-
"hidden_size": 1024,
|
| 56 |
-
"id2label": {
|
| 57 |
-
"0": "LABEL_0",
|
| 58 |
-
"1": "LABEL_1"
|
| 59 |
-
},
|
| 60 |
-
"image_size": 224,
|
| 61 |
-
"initializer_factor": 1.0,
|
| 62 |
-
"initializer_range": 0.02,
|
| 63 |
-
"intermediate_size": 4096,
|
| 64 |
-
"is_decoder": false,
|
| 65 |
-
"is_encoder_decoder": false,
|
| 66 |
-
"label2id": {
|
| 67 |
-
"LABEL_0": 0,
|
| 68 |
-
"LABEL_1": 1
|
| 69 |
-
},
|
| 70 |
-
"layer_norm_eps": 1e-05,
|
| 71 |
-
"length_penalty": 1.0,
|
| 72 |
-
"max_length": 20,
|
| 73 |
-
"min_length": 0,
|
| 74 |
-
"model_type": "clip_vision_model",
|
| 75 |
-
"no_repeat_ngram_size": 0,
|
| 76 |
-
"num_attention_heads": 16,
|
| 77 |
-
"num_beam_groups": 1,
|
| 78 |
-
"num_beams": 1,
|
| 79 |
-
"num_channels": 3,
|
| 80 |
-
"num_hidden_layers": 24,
|
| 81 |
-
"num_return_sequences": 1,
|
| 82 |
-
"output_attentions": false,
|
| 83 |
-
"output_hidden_states": false,
|
| 84 |
-
"output_scores": false,
|
| 85 |
-
"pad_token_id": null,
|
| 86 |
-
"patch_size": 14,
|
| 87 |
-
"prefix": null,
|
| 88 |
-
"problem_type": null,
|
| 89 |
-
"projection_dim": 512,
|
| 90 |
-
"pruned_heads": {},
|
| 91 |
-
"remove_invalid_values": false,
|
| 92 |
-
"repetition_penalty": 1.0,
|
| 93 |
-
"return_dict": true,
|
| 94 |
-
"return_dict_in_generate": false,
|
| 95 |
-
"sep_token_id": null,
|
| 96 |
-
"suppress_tokens": null,
|
| 97 |
-
"task_specific_params": null,
|
| 98 |
-
"temperature": 1.0,
|
| 99 |
-
"tf_legacy_loss": false,
|
| 100 |
-
"tie_encoder_decoder": false,
|
| 101 |
-
"tie_word_embeddings": true,
|
| 102 |
-
"tokenizer_class": null,
|
| 103 |
-
"top_k": 50,
|
| 104 |
-
"top_p": 1.0,
|
| 105 |
-
"torch_dtype": null,
|
| 106 |
-
"torchscript": false,
|
| 107 |
-
"transformers_version": "4.30.1",
|
| 108 |
-
"typical_p": 1.0,
|
| 109 |
-
"use_bfloat16": false
|
| 110 |
-
}
|
| 111 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/injecting_falcon_into_flamingo.py
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import torch
|
| 3 |
-
from .configuration_flamingo import FlamingoConfig
|
| 4 |
-
from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 5 |
-
|
| 6 |
-
root_dir = os.environ["AZP"]
|
| 7 |
-
print(root_dir)
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
config = FlamingoConfig.from_json_file(".flamingo-falcon-7B.json")
|
| 11 |
-
model = FlamingoForConditionalGeneration(config=config)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
state_dict_files = [
|
| 15 |
-
f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00001-of-00002.bin",
|
| 16 |
-
f"{root_dir}/otter/checkpoints/falcon-7b/pytorch_model-00002-of-00002.bin",
|
| 17 |
-
]
|
| 18 |
-
|
| 19 |
-
state_dict = {}
|
| 20 |
-
for file in state_dict_files:
|
| 21 |
-
state_dict_part = torch.load(file, map_location="cpu")
|
| 22 |
-
state_dict.update(state_dict_part)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
state_dict_3 = torch.load("{root_dir}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
|
| 26 |
-
for cur_key in list(state_dict_3.keys()):
|
| 27 |
-
if "vision_encoder" not in cur_key:
|
| 28 |
-
del state_dict_3[cur_key]
|
| 29 |
-
|
| 30 |
-
_ = model.load_state_dict(
|
| 31 |
-
state_dict_3,
|
| 32 |
-
False,
|
| 33 |
-
)
|
| 34 |
-
print(_[1])
|
| 35 |
-
|
| 36 |
-
save_state_dict_1 = {}
|
| 37 |
-
for key in state_dict:
|
| 38 |
-
if ".h." in key:
|
| 39 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 40 |
-
target_key = f"transformer.h.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 41 |
-
else:
|
| 42 |
-
target_key = key
|
| 43 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 44 |
-
_ = model.lang_encoder.load_state_dict(
|
| 45 |
-
save_state_dict_1,
|
| 46 |
-
False,
|
| 47 |
-
)
|
| 48 |
-
print(_[1])
|
| 49 |
-
model.save_pretrained(f"{root_dir}/otter/checkpoints/flamingo-falcon-7b/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/injecting_llama2_into_flamingo.py
DELETED
|
@@ -1,95 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
from .configuration_flamingo import FlamingoConfig
|
| 10 |
-
from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 11 |
-
|
| 12 |
-
# from .configuration_flamingo import FlamingoConfig
|
| 13 |
-
# from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 14 |
-
|
| 15 |
-
parser = argparse.ArgumentParser(description="Convert Vicuna model")
|
| 16 |
-
parser.add_argument("--model_choice", type=str, default="13B", help="Choose either '7B' or '13B'")
|
| 17 |
-
parser.add_argument("--llama2_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 18 |
-
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 19 |
-
args = parser.parse_args()
|
| 20 |
-
|
| 21 |
-
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 22 |
-
|
| 23 |
-
root_dir = args.llama2_root_dir
|
| 24 |
-
model_choice = args.model_choice
|
| 25 |
-
save_root_dir = args.save_root_dir
|
| 26 |
-
|
| 27 |
-
# prepare vicuna model at first
|
| 28 |
-
# you can visit https://huggingface.co/lmsys/Llama-2-33b-chat-hf to download 7B and 30B instruct checkpoints.
|
| 29 |
-
if model_choice == "7B":
|
| 30 |
-
config_file = "./flamingo/flamingo-llama2-chat-7B.json"
|
| 31 |
-
state_dict_files = [
|
| 32 |
-
f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00001-of-00002.bin",
|
| 33 |
-
f"{root_dir}/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin",
|
| 34 |
-
]
|
| 35 |
-
save_path = f"{save_root_dir}/flamingo-llama2-chat-7B-init"
|
| 36 |
-
elif model_choice == "13B":
|
| 37 |
-
config_file = "./flamingo/flamingo-llama2-chat-13B.json"
|
| 38 |
-
state_dict_files = [
|
| 39 |
-
f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00001-of-00003.bin",
|
| 40 |
-
f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00002-of-00003.bin",
|
| 41 |
-
f"{root_dir}/Llama-2-13b-chat-hf/pytorch_model-00003-of-00003.bin",
|
| 42 |
-
]
|
| 43 |
-
save_path = f"{save_root_dir}/flamingo-llama2-chat-13B-init"
|
| 44 |
-
else:
|
| 45 |
-
raise ValueError("Invalid model_choice. Choose either '13B' or '7B'.")
|
| 46 |
-
|
| 47 |
-
config = FlamingoConfig.from_json_file(config_file)
|
| 48 |
-
model = FlamingoForConditionalGeneration(config=config)
|
| 49 |
-
|
| 50 |
-
# load flamingo's vision encoder from last checkpoint.
|
| 51 |
-
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
|
| 52 |
-
# AZP = "os.environ["AZP"]"
|
| 53 |
-
AZP = os.environ["AZP"]
|
| 54 |
-
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
|
| 55 |
-
for cur_key in list(state_dict_3.keys()):
|
| 56 |
-
if "vision_encoder" not in cur_key:
|
| 57 |
-
del state_dict_3[cur_key]
|
| 58 |
-
|
| 59 |
-
load_msg = model.load_state_dict(
|
| 60 |
-
state_dict_3,
|
| 61 |
-
False,
|
| 62 |
-
)
|
| 63 |
-
# print incompatible keys
|
| 64 |
-
print(load_msg[1])
|
| 65 |
-
|
| 66 |
-
# Loading vicuna weights
|
| 67 |
-
state_dict = {}
|
| 68 |
-
for file in tqdm(state_dict_files, desc="Loading state dict"):
|
| 69 |
-
state_dict_part = torch.load(file, map_location="cpu")
|
| 70 |
-
state_dict.update(state_dict_part)
|
| 71 |
-
|
| 72 |
-
save_state_dict_1 = {}
|
| 73 |
-
for key in state_dict:
|
| 74 |
-
if ".layers." in key:
|
| 75 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 76 |
-
target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 77 |
-
else:
|
| 78 |
-
target_key = key
|
| 79 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 80 |
-
|
| 81 |
-
# Reshape the token embedding to 50280 for compatible
|
| 82 |
-
model.lang_encoder.resize_token_embeddings(32000)
|
| 83 |
-
|
| 84 |
-
load_msg = model.lang_encoder.load_state_dict(
|
| 85 |
-
save_state_dict_1,
|
| 86 |
-
False,
|
| 87 |
-
)
|
| 88 |
-
# Reshape the token embedding to 32002 for compatible
|
| 89 |
-
model.lang_encoder.resize_token_embeddings(32002)
|
| 90 |
-
# print incompatible keys
|
| 91 |
-
print(load_msg[1])
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
print(f"Saving model to {save_path}...")
|
| 95 |
-
model.save_pretrained(save_path, max_shard_size="10GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/injecting_mpt-1B-redpajama_into_flamingo.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
from configuration_flamingo import FlamingoConfig
|
| 10 |
-
from modeling_flamingo import FlamingoForConditionalGeneration
|
| 11 |
-
from utils import rename_flamingo_checkpoint
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
parser = argparse.ArgumentParser(description="Convert MPT model")
|
| 15 |
-
parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 16 |
-
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 17 |
-
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
|
| 18 |
-
args = parser.parse_args()
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
root_dir = args.mpt_root_dir
|
| 22 |
-
save_root_dir = args.save_root_dir
|
| 23 |
-
|
| 24 |
-
# prepare mpt model at first
|
| 25 |
-
# you can visit https://huggingface.co/mosaicml to download 7B and 30B instruct checkpoints.
|
| 26 |
-
config_file = "./flamingo/flamingo-mpt-1B-redpajama.json"
|
| 27 |
-
state_dict_file = f"{root_dir}/pytorch_model.bin"
|
| 28 |
-
save_path = f"{save_root_dir}/flamingo-mpt-1b-redpajama-200b-dolly"
|
| 29 |
-
|
| 30 |
-
config = FlamingoConfig.from_json_file(config_file)
|
| 31 |
-
|
| 32 |
-
model = FlamingoForConditionalGeneration(config=config)
|
| 33 |
-
|
| 34 |
-
# Loading mpt weights
|
| 35 |
-
state_dict = torch.load(state_dict_file, map_location="cpu")
|
| 36 |
-
save_state_dict_1 = {}
|
| 37 |
-
for key in state_dict:
|
| 38 |
-
if ".blocks." in key:
|
| 39 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 40 |
-
target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 41 |
-
else:
|
| 42 |
-
target_key = key
|
| 43 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 44 |
-
|
| 45 |
-
load_msg = model.lang_encoder.load_state_dict(
|
| 46 |
-
save_state_dict_1,
|
| 47 |
-
False,
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
# load flamingo's vision encoder from last checkpoint.
|
| 51 |
-
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
|
| 52 |
-
AZP = os.environ["AZP"]
|
| 53 |
-
state_dict_3 = torch.load(f"{AZP}/pytorch_model-00004-of-00004.bin", map_location="cpu")
|
| 54 |
-
for cur_key in list(state_dict_3.keys()):
|
| 55 |
-
if "vision_encoder" not in cur_key:
|
| 56 |
-
del state_dict_3[cur_key]
|
| 57 |
-
|
| 58 |
-
load_msg = model.load_state_dict(
|
| 59 |
-
state_dict_3,
|
| 60 |
-
False,
|
| 61 |
-
)
|
| 62 |
-
# print incompatible keys
|
| 63 |
-
print(load_msg[1])
|
| 64 |
-
|
| 65 |
-
save_state_dict_1 = {}
|
| 66 |
-
for key in state_dict:
|
| 67 |
-
if ".blocks." in key:
|
| 68 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 69 |
-
target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 70 |
-
else:
|
| 71 |
-
target_key = key
|
| 72 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 73 |
-
|
| 74 |
-
load_msg = model.lang_encoder.load_state_dict(
|
| 75 |
-
save_state_dict_1,
|
| 76 |
-
False,
|
| 77 |
-
)
|
| 78 |
-
# print incompatible keys
|
| 79 |
-
print(load_msg[1])
|
| 80 |
-
if args.flamingo_dir is not None:
|
| 81 |
-
state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
|
| 82 |
-
save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
|
| 83 |
-
real_vocab_size = config.text_config.vocab_size
|
| 84 |
-
# Reshape the token embedding to 50280 for compatible
|
| 85 |
-
model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
|
| 86 |
-
|
| 87 |
-
load_msg = model.load_state_dict(
|
| 88 |
-
save_state_dict_2,
|
| 89 |
-
False,
|
| 90 |
-
)
|
| 91 |
-
# print incompatible keys
|
| 92 |
-
print(load_msg[1])
|
| 93 |
-
# Reshape the token embedding to 50432
|
| 94 |
-
model.lang_encoder.resize_token_embeddings(real_vocab_size)
|
| 95 |
-
|
| 96 |
-
print(f"Saving model to {save_path}...")
|
| 97 |
-
model.save_pretrained(save_path, max_shard_size="10GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/injecting_mpt_into_flamingo.py
DELETED
|
@@ -1,109 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
from configuration_flamingo import FlamingoConfig
|
| 10 |
-
from modeling_flamingo import FlamingoForConditionalGeneration
|
| 11 |
-
from utils import rename_flamingo_checkpoint
|
| 12 |
-
|
| 13 |
-
parser = argparse.ArgumentParser(description="Convert MPT model")
|
| 14 |
-
parser.add_argument("--model_choice", type=str, choices=["7B", "30B"], required=True, help="Choose either '7B' or '30B'")
|
| 15 |
-
parser.add_argument("--mpt_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 16 |
-
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 17 |
-
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
|
| 18 |
-
args = parser.parse_args()
|
| 19 |
-
|
| 20 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 21 |
-
|
| 22 |
-
root_dir = args.mpt_root_dir
|
| 23 |
-
model_choice = args.model_choice
|
| 24 |
-
save_root_dir = args.save_root_dir
|
| 25 |
-
|
| 26 |
-
# prepare mpt model at first
|
| 27 |
-
# you can visit https://huggingface.co/mosaicml to download 7B and 30B instruct checkpoints.
|
| 28 |
-
if model_choice == "30B":
|
| 29 |
-
config_file = "./flamingo/flamingo-mpt-30B.json"
|
| 30 |
-
state_dict_files = [
|
| 31 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00001-of-00007.bin",
|
| 32 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00002-of-00007.bin",
|
| 33 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00003-of-00007.bin",
|
| 34 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00004-of-00007.bin",
|
| 35 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00005-of-00007.bin",
|
| 36 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00006-of-00007.bin",
|
| 37 |
-
f"{root_dir}/mpt-30b-instruct/pytorch_model-00007-of-00007.bin",
|
| 38 |
-
]
|
| 39 |
-
save_path = f"{save_root_dir}/flamingo-mpt-30B-instruct-init"
|
| 40 |
-
elif model_choice == "7B":
|
| 41 |
-
config_file = "./flamingo/flamingo-mpt-7B.json"
|
| 42 |
-
state_dict_files = [
|
| 43 |
-
f"{root_dir}/mpt-7b/pytorch_model-00001-of-00002.bin",
|
| 44 |
-
f"{root_dir}/mpt-7b/pytorch_model-00002-of-00002.bin",
|
| 45 |
-
]
|
| 46 |
-
save_path = f"{save_root_dir}/flamingo-mpt-7B"
|
| 47 |
-
else:
|
| 48 |
-
raise ValueError("Invalid model_choice. Choose either '30B' or '7B'.")
|
| 49 |
-
|
| 50 |
-
config = FlamingoConfig.from_json_file(config_file)
|
| 51 |
-
|
| 52 |
-
model = FlamingoForConditionalGeneration(config=config)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
# load flamingo's vision encoder from last checkpoint.
|
| 56 |
-
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
|
| 57 |
-
AZP = os.environ["AZP"]
|
| 58 |
-
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
|
| 59 |
-
for cur_key in list(state_dict_3.keys()):
|
| 60 |
-
if "vision_encoder" not in cur_key:
|
| 61 |
-
del state_dict_3[cur_key]
|
| 62 |
-
|
| 63 |
-
load_msg = model.load_state_dict(
|
| 64 |
-
state_dict_3,
|
| 65 |
-
False,
|
| 66 |
-
)
|
| 67 |
-
# print incompatible keys
|
| 68 |
-
print(load_msg[1])
|
| 69 |
-
|
| 70 |
-
# Loading mpt weights
|
| 71 |
-
state_dict = {}
|
| 72 |
-
for file in tqdm(state_dict_files, desc="Loading state dict"):
|
| 73 |
-
state_dict_part = torch.load(file, map_location="cpu")
|
| 74 |
-
state_dict.update(state_dict_part)
|
| 75 |
-
|
| 76 |
-
save_state_dict_1 = {}
|
| 77 |
-
for key in state_dict:
|
| 78 |
-
if ".blocks." in key:
|
| 79 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 80 |
-
target_key = f"transformer.blocks.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 81 |
-
else:
|
| 82 |
-
target_key = key
|
| 83 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 84 |
-
|
| 85 |
-
load_msg = model.lang_encoder.load_state_dict(
|
| 86 |
-
save_state_dict_1,
|
| 87 |
-
False,
|
| 88 |
-
)
|
| 89 |
-
# print incompatible keys
|
| 90 |
-
print(load_msg[1])
|
| 91 |
-
if args.flamingo_dir is not None:
|
| 92 |
-
state_dict_2 = torch.load(f"{args.flamingo_dir}/checkpoint.pt", map_location="cpu")
|
| 93 |
-
save_state_dict_2 = rename_flamingo_checkpoint(state_dict_2)
|
| 94 |
-
|
| 95 |
-
real_vocab_size = config.text_config.vocab_size
|
| 96 |
-
# Reshape the token embedding to 50280 for compatible
|
| 97 |
-
model.lang_encoder.resize_token_embeddings(save_state_dict_2["lang_encoder.transformer.wte.weight"].shape[0])
|
| 98 |
-
|
| 99 |
-
load_msg = model.load_state_dict(
|
| 100 |
-
save_state_dict_2,
|
| 101 |
-
False,
|
| 102 |
-
)
|
| 103 |
-
# print incompatible keys
|
| 104 |
-
print(load_msg[1])
|
| 105 |
-
# Reshape the token embedding to 50432
|
| 106 |
-
model.lang_encoder.resize_token_embeddings(real_vocab_size)
|
| 107 |
-
|
| 108 |
-
print(f"Saving model to {save_path}...")
|
| 109 |
-
model.save_pretrained(save_path, max_shard_size="10GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/injecting_vicuna_into_flamingo.py
DELETED
|
@@ -1,100 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
from .configuration_flamingo import FlamingoConfig
|
| 10 |
-
from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 11 |
-
|
| 12 |
-
# from .configuration_flamingo import FlamingoConfig
|
| 13 |
-
# from .modeling_flamingo import FlamingoForConditionalGeneration
|
| 14 |
-
|
| 15 |
-
parser = argparse.ArgumentParser(description="Convert Vicuna model")
|
| 16 |
-
parser.add_argument("--model_choice", type=str, choices=["7B", "33B"], required=True, help="Choose either '7B' or '33B'")
|
| 17 |
-
parser.add_argument("--vicuna_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 18 |
-
parser.add_argument("--save_root_dir", type=str, default="/home/luodian/projects/checkpoints")
|
| 19 |
-
parser.add_argument("--flamingo_dir", type=str, default=None, help="If the pretrained flamingo weights also need to be injected")
|
| 20 |
-
args = parser.parse_args()
|
| 21 |
-
|
| 22 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 23 |
-
|
| 24 |
-
root_dir = args.vicuna_root_dir
|
| 25 |
-
model_choice = args.model_choice
|
| 26 |
-
save_root_dir = args.save_root_dir
|
| 27 |
-
|
| 28 |
-
# prepare vicuna model at first
|
| 29 |
-
# you can visit https://huggingface.co/lmsys/vicuna-33b-v1.3 to download 7B and 30B instruct checkpoints.
|
| 30 |
-
if model_choice == "33B":
|
| 31 |
-
config_file = "./flamingo/flamingo-vicuna-33B-v1.3.json"
|
| 32 |
-
state_dict_files = [
|
| 33 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00001-of-00007.bin",
|
| 34 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00002-of-00007.bin",
|
| 35 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00003-of-00007.bin",
|
| 36 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00004-of-00007.bin",
|
| 37 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00005-of-00007.bin",
|
| 38 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00006-of-00007.bin",
|
| 39 |
-
f"{root_dir}/vicuna-33b-v1.3/pytorch_model-00007-of-00007.bin",
|
| 40 |
-
]
|
| 41 |
-
save_path = f"{save_root_dir}/flamingo-vicuna-33B-v1.3-init"
|
| 42 |
-
elif model_choice == "7B":
|
| 43 |
-
config_file = "./flamingo/flamingo-vicuna-7B-v1.3.json"
|
| 44 |
-
state_dict_files = [
|
| 45 |
-
f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00001-of-00002.bin",
|
| 46 |
-
f"{root_dir}/vicuna-7b-v1.3/pytorch_model-00002-of-00002.bin",
|
| 47 |
-
]
|
| 48 |
-
save_path = f"{save_root_dir}/flamingo-vicuna-7B-v1.3-init"
|
| 49 |
-
else:
|
| 50 |
-
raise ValueError("Invalid model_choice. Choose either '33B' or '7B'.")
|
| 51 |
-
|
| 52 |
-
config = FlamingoConfig.from_json_file(config_file)
|
| 53 |
-
model = FlamingoForConditionalGeneration(config=config)
|
| 54 |
-
|
| 55 |
-
# load flamingo's vision encoder from last checkpoint.
|
| 56 |
-
# you can visit https://huggingface.co/luodian/openflamingo-9b-hf/tree/main to download the checkpoint.
|
| 57 |
-
# AZP = "os.environ["AZP"]"
|
| 58 |
-
AZP = os.environ["AZP"]
|
| 59 |
-
state_dict_3 = torch.load(f"{AZP}/otter/checkpoints/flamingo_9b_hf/pytorch_model-00004-of-00004.bin", map_location="cpu")
|
| 60 |
-
for cur_key in list(state_dict_3.keys()):
|
| 61 |
-
if "vision_encoder" not in cur_key:
|
| 62 |
-
del state_dict_3[cur_key]
|
| 63 |
-
|
| 64 |
-
load_msg = model.load_state_dict(
|
| 65 |
-
state_dict_3,
|
| 66 |
-
False,
|
| 67 |
-
)
|
| 68 |
-
# print incompatible keys
|
| 69 |
-
print(load_msg[1])
|
| 70 |
-
|
| 71 |
-
# Loading vicuna weights
|
| 72 |
-
state_dict = {}
|
| 73 |
-
for file in tqdm(state_dict_files, desc="Loading state dict"):
|
| 74 |
-
state_dict_part = torch.load(file, map_location="cpu")
|
| 75 |
-
state_dict.update(state_dict_part)
|
| 76 |
-
|
| 77 |
-
save_state_dict_1 = {}
|
| 78 |
-
for key in state_dict:
|
| 79 |
-
if ".layers." in key:
|
| 80 |
-
_, _, layer_num, *remain_names = key.split(".")
|
| 81 |
-
target_key = f"model.layers.{layer_num}.decoder_layer.{'.'.join(remain_names)}"
|
| 82 |
-
else:
|
| 83 |
-
target_key = key
|
| 84 |
-
save_state_dict_1[f"{target_key}"] = state_dict[key]
|
| 85 |
-
|
| 86 |
-
# Reshape the token embedding to 50280 for compatible
|
| 87 |
-
model.lang_encoder.resize_token_embeddings(32000)
|
| 88 |
-
|
| 89 |
-
load_msg = model.lang_encoder.load_state_dict(
|
| 90 |
-
save_state_dict_1,
|
| 91 |
-
False,
|
| 92 |
-
)
|
| 93 |
-
# Reshape the token embedding to 32002 for compatible
|
| 94 |
-
model.lang_encoder.resize_token_embeddings(32002)
|
| 95 |
-
# print incompatible keys
|
| 96 |
-
print(load_msg[1])
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
print(f"Saving model to {save_path}...")
|
| 100 |
-
model.save_pretrained(save_path, max_shard_size="10GB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/modeling_flamingo.py
DELETED
|
@@ -1,966 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
from typing import Callable, Optional
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
| 8 |
-
from einops import rearrange, repeat
|
| 9 |
-
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
|
| 10 |
-
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
-
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
-
from transformers.models.auto import AutoModel, AutoModelForCausalLM, AutoTokenizer
|
| 13 |
-
|
| 14 |
-
from .configuration_flamingo import FlamingoConfig
|
| 15 |
-
from .falcon.modelling_RW import RWForCausalLM
|
| 16 |
-
from .mpt.modeling_mpt import MPTForCausalLM
|
| 17 |
-
from .mpt_redpajama.mosaic_gpt import MosaicGPT
|
| 18 |
-
|
| 19 |
-
# from .configuration_flamingo import FlamingoConfig
|
| 20 |
-
|
| 21 |
-
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
| 22 |
-
"opt": "model.decoder.layers",
|
| 23 |
-
"gptneo": "transformer.h",
|
| 24 |
-
"gptj": "transformer.h",
|
| 25 |
-
"gpt-j": "transformer.h",
|
| 26 |
-
"pythia": "gpt_neox.layers",
|
| 27 |
-
"llama": "model.layers",
|
| 28 |
-
"RWForCausalLM": "transformer.h",
|
| 29 |
-
"MPTForCausalLM": "transformer.blocks",
|
| 30 |
-
"MosaicGPT": "transformer.blocks",
|
| 31 |
-
}
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _infer_decoder_layers_attr_name(model: nn.Module):
|
| 35 |
-
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
| 36 |
-
if k.lower() in model.__class__.__name__.lower():
|
| 37 |
-
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
| 38 |
-
|
| 39 |
-
raise ValueError(
|
| 40 |
-
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
def extend_instance(obj, mixin):
|
| 45 |
-
"""Apply mixins to a class instance after creation"""
|
| 46 |
-
base_cls = obj.__class__
|
| 47 |
-
base_cls_name = obj.__class__.__name__
|
| 48 |
-
obj.__class__ = type(base_cls_name, (mixin, base_cls), {}) # mixin needs to go first for our forward() logic to work
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def getattr_recursive(obj, att):
|
| 52 |
-
"""
|
| 53 |
-
Return nested attribute of obj
|
| 54 |
-
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
| 55 |
-
"""
|
| 56 |
-
if att == "":
|
| 57 |
-
return obj
|
| 58 |
-
i = att.find(".")
|
| 59 |
-
if i < 0:
|
| 60 |
-
return getattr(obj, att)
|
| 61 |
-
else:
|
| 62 |
-
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def setattr_recursive(obj, att, val):
|
| 66 |
-
"""
|
| 67 |
-
Set nested attribute of obj
|
| 68 |
-
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
| 69 |
-
"""
|
| 70 |
-
if "." in att:
|
| 71 |
-
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
| 72 |
-
setattr(obj, att.split(".")[-1], val)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def exists(val):
|
| 76 |
-
return val is not None
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
class FlamingoPerceiverBlock(nn.Module):
|
| 80 |
-
def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4):
|
| 81 |
-
super().__init__()
|
| 82 |
-
self.scale = dim_head**-0.5
|
| 83 |
-
self.heads = heads
|
| 84 |
-
inner_dim = dim_head * heads
|
| 85 |
-
ff_dim = dim * mult
|
| 86 |
-
self.norm_media = nn.LayerNorm(dim)
|
| 87 |
-
self.norm_latents = nn.LayerNorm(dim)
|
| 88 |
-
|
| 89 |
-
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 90 |
-
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
| 91 |
-
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 92 |
-
self.feed_forward = nn.ModuleList(
|
| 93 |
-
[
|
| 94 |
-
nn.LayerNorm(dim),
|
| 95 |
-
nn.Linear(dim, ff_dim, bias=False),
|
| 96 |
-
nn.GELU(),
|
| 97 |
-
nn.Linear(ff_dim, dim, bias=False),
|
| 98 |
-
]
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
|
| 102 |
-
"""
|
| 103 |
-
Args:
|
| 104 |
-
x (torch.Tensor): image features
|
| 105 |
-
shape (b, T, n1, D)
|
| 106 |
-
latent (torch.Tensor): latent features
|
| 107 |
-
shape (b, T, n2, D)
|
| 108 |
-
"""
|
| 109 |
-
x = self.norm_media(x)
|
| 110 |
-
residual_latents = latents
|
| 111 |
-
latents = self.norm_latents(latents)
|
| 112 |
-
|
| 113 |
-
h = self.heads
|
| 114 |
-
|
| 115 |
-
q = self.to_q(latents)
|
| 116 |
-
kv_input = torch.cat((x, latents), dim=-2)
|
| 117 |
-
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
| 118 |
-
q = rearrange(q, "b t n (h d) -> b h t n d", h=h)
|
| 119 |
-
k = rearrange(k, "b t n (h d) -> b h t n d", h=h)
|
| 120 |
-
v = rearrange(v, "b t n (h d) -> b h t n d", h=h)
|
| 121 |
-
q = q * self.scale
|
| 122 |
-
|
| 123 |
-
# attention
|
| 124 |
-
sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
|
| 125 |
-
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
| 126 |
-
attn = sim.softmax(dim=-1)
|
| 127 |
-
|
| 128 |
-
out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
| 129 |
-
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
| 130 |
-
out = self.to_out(out) + residual_latents
|
| 131 |
-
residual_out = out
|
| 132 |
-
for layer in self.feed_forward:
|
| 133 |
-
out = layer(out)
|
| 134 |
-
return out + residual_out
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class FlamingoPerceiverResampler(nn.Module):
|
| 138 |
-
def __init__(
|
| 139 |
-
self,
|
| 140 |
-
*,
|
| 141 |
-
dim: int,
|
| 142 |
-
depth: int = 6,
|
| 143 |
-
dim_head: int = 64,
|
| 144 |
-
heads: int = 8,
|
| 145 |
-
num_latents: int = 64,
|
| 146 |
-
# max_num_frames: int = 128,
|
| 147 |
-
max_num_media: Optional[int] = None,
|
| 148 |
-
max_num_frames: Optional[int] = None,
|
| 149 |
-
ff_mult: int = 4,
|
| 150 |
-
):
|
| 151 |
-
super().__init__()
|
| 152 |
-
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
| 153 |
-
self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
|
| 154 |
-
# self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
|
| 155 |
-
|
| 156 |
-
self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
|
| 157 |
-
|
| 158 |
-
self.layers = nn.ModuleList([])
|
| 159 |
-
for _ in range(depth):
|
| 160 |
-
self.layers.append(FlamingoPerceiverBlock(dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult))
|
| 161 |
-
|
| 162 |
-
self.norm = nn.LayerNorm(dim)
|
| 163 |
-
|
| 164 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 165 |
-
"""
|
| 166 |
-
Args:
|
| 167 |
-
x (torch.Tensor): image features
|
| 168 |
-
shape (b, T, F, v, D)
|
| 169 |
-
Returns:
|
| 170 |
-
shape (b, T, n, D) where n is self.num_latents
|
| 171 |
-
"""
|
| 172 |
-
b, T, F, v = x.shape[:4]
|
| 173 |
-
|
| 174 |
-
# frame and media time embeddings
|
| 175 |
-
if exists(self.frame_embs):
|
| 176 |
-
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
| 177 |
-
x = x + frame_embs
|
| 178 |
-
x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
|
| 179 |
-
if exists(self.media_time_embs):
|
| 180 |
-
x = x + self.media_time_embs[:T]
|
| 181 |
-
|
| 182 |
-
# blocks
|
| 183 |
-
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
| 184 |
-
for block in self.layers:
|
| 185 |
-
latents = block(x, latents)
|
| 186 |
-
return self.norm(latents)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
class FlamingoMaskedCrossAttention(nn.Module):
|
| 190 |
-
def __init__(
|
| 191 |
-
self,
|
| 192 |
-
*,
|
| 193 |
-
dim: int,
|
| 194 |
-
dim_visual: int,
|
| 195 |
-
dim_head: int = 64,
|
| 196 |
-
heads: int = 8,
|
| 197 |
-
only_attend_immediate_media: bool = True,
|
| 198 |
-
):
|
| 199 |
-
super().__init__()
|
| 200 |
-
self.scale = dim_head**-0.5
|
| 201 |
-
self.heads = heads
|
| 202 |
-
inner_dim = dim_head * heads
|
| 203 |
-
|
| 204 |
-
self.norm = nn.LayerNorm(dim)
|
| 205 |
-
|
| 206 |
-
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
| 207 |
-
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
|
| 208 |
-
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
| 209 |
-
|
| 210 |
-
# whether for text to only attend to immediate preceding image, or all previous images
|
| 211 |
-
self.only_attend_immediate_media = only_attend_immediate_media
|
| 212 |
-
|
| 213 |
-
def forward(
|
| 214 |
-
self,
|
| 215 |
-
x: torch.Tensor,
|
| 216 |
-
media: torch.Tensor,
|
| 217 |
-
media_locations: Optional[torch.BoolTensor] = None,
|
| 218 |
-
attend_previous: bool = True,
|
| 219 |
-
) -> torch.Tensor:
|
| 220 |
-
"""
|
| 221 |
-
Args:
|
| 222 |
-
x (torch.Tensor): text features
|
| 223 |
-
shape (B, T_txt, D_txt)
|
| 224 |
-
media (torch.Tensor): image features
|
| 225 |
-
shape (B, T_img, n, D_img) where n is the dim of the latents
|
| 226 |
-
media_locations: boolean mask identifying the media tokens in x
|
| 227 |
-
shape (B, T_txt)
|
| 228 |
-
attend_previous: bool
|
| 229 |
-
If false, ignores immediately preceding image and starts attending when following image
|
| 230 |
-
"""
|
| 231 |
-
_, T_img, n = media.shape[:3]
|
| 232 |
-
h = self.heads
|
| 233 |
-
|
| 234 |
-
x = self.norm(x)
|
| 235 |
-
|
| 236 |
-
q = self.to_q(x)
|
| 237 |
-
media = rearrange(media, "b t n d -> b (t n) d")
|
| 238 |
-
|
| 239 |
-
k, v = self.to_kv(media).chunk(2, dim=-1)
|
| 240 |
-
q = rearrange(q, "b n (h d) -> b h n d", h=h)
|
| 241 |
-
k = rearrange(k, "b n (h d) -> b h n d", h=h)
|
| 242 |
-
v = rearrange(v, "b n (h d) -> b h n d", h=h)
|
| 243 |
-
|
| 244 |
-
q = q * self.scale
|
| 245 |
-
|
| 246 |
-
sim = torch.einsum("... i d, ... j d -> ... i j", q, k)
|
| 247 |
-
|
| 248 |
-
if exists(media_locations):
|
| 249 |
-
# at each boolean of True, increment the time counter (relative to media time)
|
| 250 |
-
text_time = media_locations.cumsum(dim=-1)
|
| 251 |
-
media_time = torch.arange(T_img, device=x.device) + 1
|
| 252 |
-
|
| 253 |
-
if not attend_previous:
|
| 254 |
-
text_time[~media_locations] += 1
|
| 255 |
-
# make sure max is still the number of images in the sequence
|
| 256 |
-
text_time[
|
| 257 |
-
text_time
|
| 258 |
-
> repeat(
|
| 259 |
-
torch.count_nonzero(media_locations, dim=1),
|
| 260 |
-
"b -> b i",
|
| 261 |
-
i=text_time.shape[1],
|
| 262 |
-
)
|
| 263 |
-
] = 0
|
| 264 |
-
|
| 265 |
-
# text time must equal media time if only attending to most immediate image
|
| 266 |
-
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
|
| 267 |
-
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
|
| 268 |
-
|
| 269 |
-
text_to_media_mask = mask_op(
|
| 270 |
-
rearrange(text_time, "b i -> b 1 i 1"),
|
| 271 |
-
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
|
| 272 |
-
)
|
| 273 |
-
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
|
| 274 |
-
|
| 275 |
-
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
| 276 |
-
attn = sim.softmax(dim=-1)
|
| 277 |
-
|
| 278 |
-
if exists(media_locations) and self.only_attend_immediate_media:
|
| 279 |
-
# any text without a preceding media needs to have attention zeroed out
|
| 280 |
-
text_without_media_mask = text_time == 0
|
| 281 |
-
text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
|
| 282 |
-
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
| 283 |
-
|
| 284 |
-
out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
|
| 285 |
-
out = rearrange(out, "b h n d -> b n (h d)")
|
| 286 |
-
return self.to_out(out)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
class FlamingoGatedCrossAttentionBlock(nn.Module):
|
| 290 |
-
def __init__(
|
| 291 |
-
self,
|
| 292 |
-
*,
|
| 293 |
-
dim: int,
|
| 294 |
-
dim_visual: int,
|
| 295 |
-
dim_head: int = 64,
|
| 296 |
-
heads: int = 8,
|
| 297 |
-
ff_mult: int = 4,
|
| 298 |
-
only_attend_immediate_media: bool = True,
|
| 299 |
-
):
|
| 300 |
-
super().__init__()
|
| 301 |
-
self.attn = FlamingoMaskedCrossAttention(
|
| 302 |
-
dim=dim,
|
| 303 |
-
dim_visual=dim_visual,
|
| 304 |
-
dim_head=dim_head,
|
| 305 |
-
heads=heads,
|
| 306 |
-
only_attend_immediate_media=only_attend_immediate_media,
|
| 307 |
-
)
|
| 308 |
-
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
|
| 309 |
-
self.feed_forward = nn.ModuleList(
|
| 310 |
-
[
|
| 311 |
-
nn.LayerNorm(dim),
|
| 312 |
-
nn.Linear(dim, dim * ff_mult, bias=False),
|
| 313 |
-
nn.GELU(),
|
| 314 |
-
nn.Linear(dim * ff_mult, dim, bias=False),
|
| 315 |
-
]
|
| 316 |
-
)
|
| 317 |
-
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
|
| 318 |
-
|
| 319 |
-
def forward(
|
| 320 |
-
self,
|
| 321 |
-
x: torch.Tensor,
|
| 322 |
-
media: torch.Tensor,
|
| 323 |
-
media_locations: Optional[torch.BoolTensor] = None,
|
| 324 |
-
attend_previous: bool = True,
|
| 325 |
-
) -> torch.Tensor:
|
| 326 |
-
x = (
|
| 327 |
-
self.attn(
|
| 328 |
-
x,
|
| 329 |
-
media,
|
| 330 |
-
media_locations=media_locations,
|
| 331 |
-
attend_previous=attend_previous,
|
| 332 |
-
)
|
| 333 |
-
* self.attn_gate.tanh()
|
| 334 |
-
+ x
|
| 335 |
-
)
|
| 336 |
-
residual_x = x
|
| 337 |
-
for ff in self.feed_forward:
|
| 338 |
-
x = ff(x)
|
| 339 |
-
x = x * self.ff_gate.tanh() + residual_x
|
| 340 |
-
|
| 341 |
-
return x
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
class FlamingoLayer(nn.Module):
|
| 345 |
-
def __init__(self, gated_cross_attn_layer: nn.Module, decoder_layer: nn.Module):
|
| 346 |
-
super().__init__()
|
| 347 |
-
self.gated_cross_attn_layer = gated_cross_attn_layer
|
| 348 |
-
self.decoder_layer = decoder_layer
|
| 349 |
-
self.vis_x = None
|
| 350 |
-
self.media_locations = None
|
| 351 |
-
|
| 352 |
-
def is_conditioned(self) -> bool:
|
| 353 |
-
"""Check whether the layer is conditioned."""
|
| 354 |
-
return self.vis_x is not None
|
| 355 |
-
|
| 356 |
-
# Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
|
| 357 |
-
def condition_vis_x(self, vis_x) -> None:
|
| 358 |
-
self.vis_x = vis_x
|
| 359 |
-
|
| 360 |
-
def condition_media_locations(self, media_locations) -> None:
|
| 361 |
-
self.media_locations = media_locations
|
| 362 |
-
|
| 363 |
-
def condition_attend_previous(self, attend_previous) -> None:
|
| 364 |
-
self.attend_previous = attend_previous
|
| 365 |
-
|
| 366 |
-
def forward(
|
| 367 |
-
self,
|
| 368 |
-
lang_x: torch.Tensor,
|
| 369 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 370 |
-
**decoder_layer_kwargs,
|
| 371 |
-
):
|
| 372 |
-
if self.gated_cross_attn_layer is None:
|
| 373 |
-
return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
| 374 |
-
|
| 375 |
-
if self.vis_x is None:
|
| 376 |
-
raise ValueError("vis_x must be conditioned before forward pass")
|
| 377 |
-
|
| 378 |
-
if self.media_locations is None:
|
| 379 |
-
raise ValueError("media_locations must be conditioned before forward pass")
|
| 380 |
-
|
| 381 |
-
lang_x = self.gated_cross_attn_layer(
|
| 382 |
-
lang_x,
|
| 383 |
-
self.vis_x,
|
| 384 |
-
media_locations=self.media_locations,
|
| 385 |
-
attend_previous=self.attend_previous,
|
| 386 |
-
)
|
| 387 |
-
lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
| 388 |
-
return lang_x
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
class FlamingoLMMixin(nn.Module):
|
| 392 |
-
"""
|
| 393 |
-
Mixin to add cross-attention layers to a language model.
|
| 394 |
-
"""
|
| 395 |
-
|
| 396 |
-
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
| 397 |
-
self.decoder_layers_attr_name = decoder_layers_attr_name
|
| 398 |
-
|
| 399 |
-
def _get_decoder_layers(self):
|
| 400 |
-
return getattr_recursive(self, self.decoder_layers_attr_name)
|
| 401 |
-
|
| 402 |
-
def _set_decoder_layers(self, value):
|
| 403 |
-
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
| 404 |
-
|
| 405 |
-
def init_flamingo(
|
| 406 |
-
self,
|
| 407 |
-
media_token_id: int,
|
| 408 |
-
vis_hidden_size: int,
|
| 409 |
-
cross_attn_every_n_layers: int,
|
| 410 |
-
use_media_placement_augmentation: bool,
|
| 411 |
-
):
|
| 412 |
-
"""
|
| 413 |
-
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
| 414 |
-
"""
|
| 415 |
-
|
| 416 |
-
gated_cross_attn_layers = nn.ModuleList(
|
| 417 |
-
[
|
| 418 |
-
FlamingoGatedCrossAttentionBlock(
|
| 419 |
-
dim=self.config.hidden_size,
|
| 420 |
-
dim_visual=vis_hidden_size,
|
| 421 |
-
)
|
| 422 |
-
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
| 423 |
-
else None
|
| 424 |
-
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
| 425 |
-
]
|
| 426 |
-
)
|
| 427 |
-
self._set_decoder_layers(
|
| 428 |
-
nn.ModuleList(
|
| 429 |
-
[
|
| 430 |
-
FlamingoLayer(gated_cross_attn_layer, decoder_layer)
|
| 431 |
-
for gated_cross_attn_layer, decoder_layer in zip(gated_cross_attn_layers, self._get_decoder_layers())
|
| 432 |
-
]
|
| 433 |
-
)
|
| 434 |
-
)
|
| 435 |
-
self.media_token_id = media_token_id
|
| 436 |
-
self.use_media_placement_augmentation = use_media_placement_augmentation
|
| 437 |
-
self.initialized_flamingo = True
|
| 438 |
-
|
| 439 |
-
def forward(self, *input, **kwargs):
|
| 440 |
-
"""Condition the Flamingo layers on the media locations before forward()"""
|
| 441 |
-
if not self.initialized_flamingo:
|
| 442 |
-
raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
|
| 443 |
-
|
| 444 |
-
input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
|
| 445 |
-
media_locations = input_ids == self.media_token_id
|
| 446 |
-
# IMPORTANT: Force `attend_previous` to True when we place training data as <image>caption<|endofchunk|>
|
| 447 |
-
# attend_previous = (
|
| 448 |
-
# (random.random() < 0.5) if self.use_media_placement_augmentation else False
|
| 449 |
-
# )
|
| 450 |
-
attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else True
|
| 451 |
-
# attend_previous = self.only_attend_previous
|
| 452 |
-
|
| 453 |
-
if self.__class__.__name__ == "LlamaForCausalLM":
|
| 454 |
-
for layer in self.get_decoder().layers:
|
| 455 |
-
layer.condition_media_locations(media_locations)
|
| 456 |
-
layer.condition_attend_previous(attend_previous)
|
| 457 |
-
elif self.__class__.__name__ in ["MPTForCausalLM", "MosaicGPT"]:
|
| 458 |
-
for layer in self.get_decoder().blocks:
|
| 459 |
-
layer.condition_media_locations(media_locations)
|
| 460 |
-
layer.condition_attend_previous(attend_previous)
|
| 461 |
-
else:
|
| 462 |
-
print("inavaliable text encoder")
|
| 463 |
-
return super().forward(*input, **kwargs) # Call the other parent's forward method
|
| 464 |
-
|
| 465 |
-
def is_conditioned(self) -> bool:
|
| 466 |
-
"""Check whether all decoder layers are already conditioned."""
|
| 467 |
-
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
| 468 |
-
|
| 469 |
-
def clear_conditioned_layers(self) -> None:
|
| 470 |
-
for layer in self._get_decoder_layers():
|
| 471 |
-
layer.condition_vis_x(None)
|
| 472 |
-
layer.condition_media_locations(None)
|
| 473 |
-
layer.condition_attend_previous(None)
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
class FlamingoPreTrainedModel(PreTrainedModel):
|
| 477 |
-
"""
|
| 478 |
-
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 479 |
-
models.
|
| 480 |
-
"""
|
| 481 |
-
|
| 482 |
-
config_class = FlamingoConfig
|
| 483 |
-
base_model_prefix = "flamingo"
|
| 484 |
-
supports_gradient_checkpointing = True
|
| 485 |
-
_no_split_modules = ["FlamingoPerceiverBlock", "CLIPEncoderLayer", "FlamingoLayer"]
|
| 486 |
-
|
| 487 |
-
def _init_weights(self, module):
|
| 488 |
-
"""Flamingo requires no specific initialization"""
|
| 489 |
-
return super()._init_weights(module)
|
| 490 |
-
|
| 491 |
-
def _set_gradient_checkpointing(self, module, value=False):
|
| 492 |
-
if isinstance(module, FlamingoModel):
|
| 493 |
-
module.gradient_checkpointing = value
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
class FlamingoModel(FlamingoPreTrainedModel):
|
| 497 |
-
config_class = FlamingoConfig
|
| 498 |
-
|
| 499 |
-
def __init__(
|
| 500 |
-
self,
|
| 501 |
-
config: FlamingoConfig,
|
| 502 |
-
):
|
| 503 |
-
super().__init__(config)
|
| 504 |
-
### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
|
| 505 |
-
if "llama" not in config.text_config._name_or_path:
|
| 506 |
-
if config.text_config.architectures[0] == "MPTForCausalLM":
|
| 507 |
-
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
|
| 508 |
-
lang_encoder = MPTForCausalLM(config=config.text_config)
|
| 509 |
-
elif config.text_config.text_config.architectures[0] == "MosaicGPT":
|
| 510 |
-
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
|
| 511 |
-
lang_encoder = MosaicGPT(config=config.text_config)
|
| 512 |
-
elif config.text_config.architectures[0] == "RWForCausalLM":
|
| 513 |
-
text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
|
| 514 |
-
lang_encoder = RWForCausalLM(config=config.text_config)
|
| 515 |
-
else:
|
| 516 |
-
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 517 |
-
lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 518 |
-
|
| 519 |
-
vision_encoder = CLIPVisionModel(config=config.vision_config)
|
| 520 |
-
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
| 521 |
-
if text_tokenizer.pad_token is None:
|
| 522 |
-
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 523 |
-
self.text_tokenizer = text_tokenizer
|
| 524 |
-
self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
|
| 525 |
-
self.media_token_id = text_tokenizer.encode("<image>")[-1]
|
| 526 |
-
|
| 527 |
-
extend_instance(lang_encoder, FlamingoLMMixin)
|
| 528 |
-
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
| 529 |
-
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
| 530 |
-
if lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 531 |
-
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
| 532 |
-
self.lang_encoder = lang_encoder
|
| 533 |
-
|
| 534 |
-
self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
|
| 535 |
-
self.use_media_placement_augmentation = config.use_media_placement_augmentation
|
| 536 |
-
|
| 537 |
-
vision_encoder.output_tokens = True
|
| 538 |
-
self.vision_encoder = vision_encoder
|
| 539 |
-
|
| 540 |
-
self.vis_dim = 1024
|
| 541 |
-
self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
|
| 542 |
-
|
| 543 |
-
self.lang_encoder.init_flamingo(
|
| 544 |
-
media_token_id=self.media_token_id,
|
| 545 |
-
vis_hidden_size=self.vis_dim,
|
| 546 |
-
cross_attn_every_n_layers=self.cross_attn_every_n_layers,
|
| 547 |
-
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
| 548 |
-
)
|
| 549 |
-
self.post_init()
|
| 550 |
-
|
| 551 |
-
def get_input_embeddings(self) -> nn.Module:
|
| 552 |
-
return self.lang_encoder.get_input_embeddings()
|
| 553 |
-
|
| 554 |
-
def set_input_embeddings(self, new_embeddings):
|
| 555 |
-
self.lang_encoder.set_input_embeddings(new_embeddings)
|
| 556 |
-
|
| 557 |
-
def get_output_embeddings(self) -> nn.Module:
|
| 558 |
-
return self.lang_encoder.get_output_embeddings()
|
| 559 |
-
|
| 560 |
-
def set_output_embeddings(self, new_embeddings):
|
| 561 |
-
self.lang_encoder.set_output_embeddings(new_embeddings)
|
| 562 |
-
|
| 563 |
-
def get_image_encoder(self) -> nn.Module:
|
| 564 |
-
return self.vision_encoder
|
| 565 |
-
|
| 566 |
-
def get_lang_encoder(self) -> nn.Module:
|
| 567 |
-
return self.lang_encoder
|
| 568 |
-
|
| 569 |
-
# def init_weights(self):
|
| 570 |
-
# # Freeze all parameters in vision encoder
|
| 571 |
-
# for param in self.vision_encoder.parameters():
|
| 572 |
-
# param.requires_grad = False
|
| 573 |
-
# # Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 574 |
-
# for name, param in self.lang_encoder.named_parameters():
|
| 575 |
-
# if "gated_cross_attn_layer" not in name:
|
| 576 |
-
# param.requires_grad = False
|
| 577 |
-
# # Unfreeze LM input embeddings
|
| 578 |
-
# self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 579 |
-
# ## MPTForCausalLM is tied word embedding
|
| 580 |
-
# if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 581 |
-
# self.lang_encoder.lm_head.requires_grad_(True)
|
| 582 |
-
# # assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 583 |
-
# # print model size in billions of parameters in 2 decimal places
|
| 584 |
-
# print(f"Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.2f} B")
|
| 585 |
-
|
| 586 |
-
def init_weights(self):
|
| 587 |
-
# Freeze all parameters in vision encoder
|
| 588 |
-
for param in self.vision_encoder.parameters():
|
| 589 |
-
param.requires_grad = False
|
| 590 |
-
|
| 591 |
-
if "lora_config" in self.config.__dict__:
|
| 592 |
-
print(f"LoRA trainable param: {(sum(p.numel() for p in self.lang_encoder.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 593 |
-
# Unfreeze gated_cross_attn_layers
|
| 594 |
-
for layer in self.lang_encoder._get_decoder_layers():
|
| 595 |
-
if layer.gated_cross_attn_layer is not None:
|
| 596 |
-
for param in layer.gated_cross_attn_layer.parameters():
|
| 597 |
-
param.requires_grad = True
|
| 598 |
-
else:
|
| 599 |
-
# Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 600 |
-
for name, param in self.lang_encoder.named_parameters():
|
| 601 |
-
if "gated_cross_attn_layer" not in name:
|
| 602 |
-
param.requires_grad = False
|
| 603 |
-
# Unfreeze LM input and output embeddings
|
| 604 |
-
self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 605 |
-
## MPTForCausalLM is tied word embedding
|
| 606 |
-
if self.lang_encoder.__class__.__name__ == "LlamaForCausalLM":
|
| 607 |
-
self.lang_encoder.lm_head.requires_grad_(True)
|
| 608 |
-
# assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 609 |
-
# print model size in billions of parameters in 2 decimal places
|
| 610 |
-
print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 611 |
-
|
| 612 |
-
def forward(
|
| 613 |
-
self,
|
| 614 |
-
vision_x: torch.Tensor,
|
| 615 |
-
lang_x: torch.Tensor,
|
| 616 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 617 |
-
labels: Optional[torch.Tensor] = None,
|
| 618 |
-
use_cached_vision_x: bool = False,
|
| 619 |
-
clear_conditioned_layers: bool = True,
|
| 620 |
-
past_key_values: Optional[torch.Tensor] = None,
|
| 621 |
-
use_cache: bool = False,
|
| 622 |
-
**kwargs,
|
| 623 |
-
) -> CausalLMOutputWithPast:
|
| 624 |
-
"""
|
| 625 |
-
Forward pass of Flamingo.
|
| 626 |
-
|
| 627 |
-
Args:
|
| 628 |
-
vision_x (torch.Tensor): Vision input
|
| 629 |
-
shape (B, T_img, F, C, H, W) with F=1
|
| 630 |
-
lang_x (torch.Tensor): Language input ids
|
| 631 |
-
shape (B, T_txt)
|
| 632 |
-
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 633 |
-
labels (torch.Tensor, optional): Labels. Defaults to None.
|
| 634 |
-
clear_conditioned_layers: if True, clear the conditioned layers
|
| 635 |
-
once the foward pass is completed. Set this to false if the
|
| 636 |
-
same set of images will be reused in another subsequent
|
| 637 |
-
forward pass.
|
| 638 |
-
past_key_values: pre-computed values to pass to language model.
|
| 639 |
-
See past_key_values documentation in Hugging Face
|
| 640 |
-
CausalLM models.
|
| 641 |
-
use_cache: whether to use cached key values. See use_cache
|
| 642 |
-
documentation in Hugging Face CausalLM models.
|
| 643 |
-
"""
|
| 644 |
-
assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."
|
| 645 |
-
|
| 646 |
-
if use_cached_vision_x:
|
| 647 |
-
# Case: use cached; vision_x should be cached and other
|
| 648 |
-
# vision-related inputs should not be provided.
|
| 649 |
-
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
|
| 650 |
-
assert self.lang_encoder.is_conditioned()
|
| 651 |
-
|
| 652 |
-
else:
|
| 653 |
-
# Case: do not use caching (i.e. this is a standard forward pass);
|
| 654 |
-
self._encode_vision_x(vision_x=vision_x)
|
| 655 |
-
|
| 656 |
-
output = self.lang_encoder(
|
| 657 |
-
input_ids=lang_x,
|
| 658 |
-
attention_mask=attention_mask,
|
| 659 |
-
labels=labels,
|
| 660 |
-
past_key_values=past_key_values,
|
| 661 |
-
use_cache=use_cache,
|
| 662 |
-
**kwargs,
|
| 663 |
-
)
|
| 664 |
-
|
| 665 |
-
if clear_conditioned_layers:
|
| 666 |
-
self.lang_encoder.clear_conditioned_layers()
|
| 667 |
-
|
| 668 |
-
return output
|
| 669 |
-
|
| 670 |
-
def _encode_vision_x(self, vision_x: torch.Tensor):
|
| 671 |
-
"""
|
| 672 |
-
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
| 673 |
-
Args:
|
| 674 |
-
vision_x (torch.Tensor): Vision input
|
| 675 |
-
shape (B, T_img, F, C, H, W)
|
| 676 |
-
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 677 |
-
Currently only F=1 is supported (single-frame videos)
|
| 678 |
-
|
| 679 |
-
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
| 680 |
-
"""
|
| 681 |
-
|
| 682 |
-
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
| 683 |
-
b, T, F = vision_x.shape[:3]
|
| 684 |
-
assert F == 1, "Only single frame supported"
|
| 685 |
-
|
| 686 |
-
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
| 687 |
-
with torch.no_grad():
|
| 688 |
-
vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
|
| 689 |
-
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
| 690 |
-
|
| 691 |
-
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
| 692 |
-
|
| 693 |
-
for layer in self.lang_encoder._get_decoder_layers():
|
| 694 |
-
layer.condition_vis_x(vision_x)
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
class FlamingoForConditionalGeneration(FlamingoPreTrainedModel):
|
| 698 |
-
config_class = FlamingoConfig
|
| 699 |
-
|
| 700 |
-
def __init__(
|
| 701 |
-
self,
|
| 702 |
-
config: FlamingoConfig,
|
| 703 |
-
):
|
| 704 |
-
super().__init__(config)
|
| 705 |
-
# TODO: hardcode right because autoXXX is too slow
|
| 706 |
-
# vision_encoder = AutoModel.from_config(config.vision_config).vision_model
|
| 707 |
-
# lang_encoder = AutoModelForCausalLM.from_config(config.text_config)
|
| 708 |
-
# text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 709 |
-
|
| 710 |
-
### TODO: give "LlamaForCausalLM" as the name of text_config.architectures of Llama_based flamingo
|
| 711 |
-
# assert hasattr(config.text_config, "_name_or_path")
|
| 712 |
-
# if "llama" not in config.text_config._name_or_path.lower():
|
| 713 |
-
if config.text_config.architectures[0] == "MPTForCausalLM":
|
| 714 |
-
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-instruct")
|
| 715 |
-
lang_encoder = MPTForCausalLM(config=config.text_config)
|
| 716 |
-
elif config.text_config.architectures[0] == "MosaicGPT":
|
| 717 |
-
text_tokenizer = AutoTokenizer.from_pretrained("mosaicml/mosaic-llama-redpajama-final-candidate")
|
| 718 |
-
lang_encoder = MosaicGPT(config=config.text_config)
|
| 719 |
-
elif config.text_config.architectures[0] == "RWForCausalLM":
|
| 720 |
-
text_tokenizer = AutoTokenizer.from_pretrained("PATH-TO-YOUR-FALCON")
|
| 721 |
-
lang_encoder = RWForCausalLM(config=config.text_config)
|
| 722 |
-
# TODO: what's the logic here?
|
| 723 |
-
elif config.text_config.architectures[0] == "LlamaForCausalLM":
|
| 724 |
-
text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 725 |
-
lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 726 |
-
else:
|
| 727 |
-
import pdb
|
| 728 |
-
|
| 729 |
-
pdb.set_trace()
|
| 730 |
-
# else:
|
| 731 |
-
# text_tokenizer = LlamaTokenizer.from_pretrained(config.text_config._name_or_path)
|
| 732 |
-
# lang_encoder = LlamaForCausalLM(config=config.text_config)
|
| 733 |
-
|
| 734 |
-
vision_encoder = CLIPVisionModel(config=config.vision_config)
|
| 735 |
-
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
| 736 |
-
if text_tokenizer.pad_token is None:
|
| 737 |
-
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
| 738 |
-
self.text_tokenizer = text_tokenizer
|
| 739 |
-
self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
|
| 740 |
-
self.media_token_id = text_tokenizer.encode("<image>")[-1]
|
| 741 |
-
|
| 742 |
-
extend_instance(lang_encoder, FlamingoLMMixin)
|
| 743 |
-
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
| 744 |
-
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
| 745 |
-
if "LlamaForCausalLM" in lang_encoder.__class__.__name__:
|
| 746 |
-
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
| 747 |
-
self.lang_encoder = lang_encoder
|
| 748 |
-
|
| 749 |
-
self.cross_attn_every_n_layers = config.cross_attn_every_n_layers if hasattr(config, "cross_attn_every_n_layers") else 4
|
| 750 |
-
self.use_media_placement_augmentation = config.use_media_placement_augmentation
|
| 751 |
-
|
| 752 |
-
vision_encoder.output_tokens = True
|
| 753 |
-
self.vision_encoder = vision_encoder
|
| 754 |
-
|
| 755 |
-
self.vis_dim = 1024
|
| 756 |
-
self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
|
| 757 |
-
|
| 758 |
-
self.lang_encoder.init_flamingo(
|
| 759 |
-
media_token_id=self.media_token_id,
|
| 760 |
-
vis_hidden_size=self.vis_dim,
|
| 761 |
-
cross_attn_every_n_layers=self.cross_attn_every_n_layers,
|
| 762 |
-
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
| 763 |
-
)
|
| 764 |
-
self.post_init()
|
| 765 |
-
|
| 766 |
-
def get_input_embeddings(self) -> nn.Module:
|
| 767 |
-
return self.lang_encoder.get_input_embeddings()
|
| 768 |
-
|
| 769 |
-
def set_input_embeddings(self, new_embeddings):
|
| 770 |
-
self.lang_encoder.set_input_embeddings(new_embeddings)
|
| 771 |
-
|
| 772 |
-
def get_output_embeddings(self) -> nn.Module:
|
| 773 |
-
return self.lang_encoder.get_output_embeddings()
|
| 774 |
-
|
| 775 |
-
def set_output_embeddings(self, new_embeddings):
|
| 776 |
-
self.lang_encoder.set_output_embeddings(new_embeddings)
|
| 777 |
-
|
| 778 |
-
def get_image_encoder(self) -> nn.Module:
|
| 779 |
-
return self.vision_encoder
|
| 780 |
-
|
| 781 |
-
def get_lang_encoder(self) -> nn.Module:
|
| 782 |
-
return self.lang_encoder
|
| 783 |
-
|
| 784 |
-
def init_weights(self):
|
| 785 |
-
# Freeze all parameters in vision encoder
|
| 786 |
-
for param in self.vision_encoder.parameters():
|
| 787 |
-
param.requires_grad = False
|
| 788 |
-
# Freeze all parameters in lang encoders except gated_cross_attn_layers
|
| 789 |
-
for name, param in self.lang_encoder.named_parameters():
|
| 790 |
-
if "gated_cross_attn_layer" not in name:
|
| 791 |
-
param.requires_grad = False
|
| 792 |
-
# Unfreeze LM input embeddings
|
| 793 |
-
self.lang_encoder.get_input_embeddings().requires_grad_(True)
|
| 794 |
-
## MPTForCausalLM is tied word embedding
|
| 795 |
-
if "LlamaForCausalLM" in self.lang_encoder.__class__.__name__:
|
| 796 |
-
self.lang_encoder.lm_head.requires_grad_(True)
|
| 797 |
-
# assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
| 798 |
-
# print model size in billions of parameters in 2 decimal places
|
| 799 |
-
print("====================Model Grad Part====================")
|
| 800 |
-
total_params = 0
|
| 801 |
-
for name, param in self.named_parameters():
|
| 802 |
-
if param.requires_grad:
|
| 803 |
-
total_params += param.numel()
|
| 804 |
-
print(f"Parameter: {name}, Size: {param.numel() / 1e6:.6f} M")
|
| 805 |
-
print(f"Total Trainable param: {total_params / 1e9:.4f} B")
|
| 806 |
-
print(f"Total Trainable param: {(sum(p.numel() for p in self.parameters() if p.requires_grad)) / 1e9:.3f} B")
|
| 807 |
-
|
| 808 |
-
def forward(
|
| 809 |
-
self,
|
| 810 |
-
vision_x: torch.Tensor,
|
| 811 |
-
lang_x: torch.Tensor,
|
| 812 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 813 |
-
labels: Optional[torch.Tensor] = None,
|
| 814 |
-
use_cached_vision_x: bool = False,
|
| 815 |
-
clear_conditioned_layers: bool = True,
|
| 816 |
-
past_key_values: Optional[torch.Tensor] = None,
|
| 817 |
-
use_cache: bool = False,
|
| 818 |
-
**kwargs,
|
| 819 |
-
) -> CausalLMOutputWithPast:
|
| 820 |
-
"""
|
| 821 |
-
Forward pass of Flamingo.
|
| 822 |
-
|
| 823 |
-
Args:
|
| 824 |
-
vision_x (torch.Tensor): Vision input
|
| 825 |
-
shape (B, T_img, F, C, H, W) with F=1
|
| 826 |
-
lang_x (torch.Tensor): Language input ids
|
| 827 |
-
shape (B, T_txt)
|
| 828 |
-
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 829 |
-
labels (torch.Tensor, optional): Labels. Defaults to None.
|
| 830 |
-
clear_conditioned_layers: if True, clear the conditioned layers
|
| 831 |
-
once the foward pass is completed. Set this to false if the
|
| 832 |
-
same set of images will be reused in another subsequent
|
| 833 |
-
forward pass.
|
| 834 |
-
past_key_values: pre-computed values to pass to language model.
|
| 835 |
-
See past_key_values documentation in Hugging Face
|
| 836 |
-
CausalLM models.
|
| 837 |
-
use_cache: whether to use cached key values. See use_cache
|
| 838 |
-
documentation in Hugging Face CausalLM models.
|
| 839 |
-
"""
|
| 840 |
-
assert (vision_x is not None) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."
|
| 841 |
-
|
| 842 |
-
if use_cached_vision_x:
|
| 843 |
-
# Case: use cached; vision_x should be cached and other
|
| 844 |
-
# vision-related inputs should not be provided.
|
| 845 |
-
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
|
| 846 |
-
assert self.lang_encoder.is_conditioned()
|
| 847 |
-
|
| 848 |
-
else:
|
| 849 |
-
# Case: do not use caching (i.e. this is a standard forward pass);
|
| 850 |
-
self._encode_vision_x(vision_x=vision_x)
|
| 851 |
-
|
| 852 |
-
output = self.lang_encoder(
|
| 853 |
-
input_ids=lang_x,
|
| 854 |
-
attention_mask=attention_mask,
|
| 855 |
-
labels=labels,
|
| 856 |
-
past_key_values=past_key_values,
|
| 857 |
-
use_cache=use_cache,
|
| 858 |
-
**kwargs,
|
| 859 |
-
)
|
| 860 |
-
|
| 861 |
-
if clear_conditioned_layers:
|
| 862 |
-
self.lang_encoder.clear_conditioned_layers()
|
| 863 |
-
|
| 864 |
-
return output
|
| 865 |
-
|
| 866 |
-
def _encode_vision_x(self, vision_x: torch.Tensor):
|
| 867 |
-
"""
|
| 868 |
-
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
| 869 |
-
Args:
|
| 870 |
-
vision_x (torch.Tensor): Vision input
|
| 871 |
-
shape (B, T_img, F, C, H, W)
|
| 872 |
-
Images in the same chunk are collated along T_img, and frames are collated along F
|
| 873 |
-
Currently only F=1 is supported (single-frame videos)
|
| 874 |
-
|
| 875 |
-
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
| 876 |
-
"""
|
| 877 |
-
|
| 878 |
-
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
| 879 |
-
b, T, F = vision_x.shape[:3]
|
| 880 |
-
# assert F == 1, "Only single frame supported"
|
| 881 |
-
|
| 882 |
-
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
| 883 |
-
with torch.no_grad():
|
| 884 |
-
vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
|
| 885 |
-
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
| 886 |
-
|
| 887 |
-
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
| 888 |
-
|
| 889 |
-
for layer in self.lang_encoder._get_decoder_layers():
|
| 890 |
-
layer.condition_vis_x(vision_x)
|
| 891 |
-
|
| 892 |
-
@torch.no_grad()
|
| 893 |
-
def generate(
|
| 894 |
-
self,
|
| 895 |
-
vision_x: torch.Tensor,
|
| 896 |
-
lang_x: torch.Tensor,
|
| 897 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 898 |
-
num_beams: int = 1,
|
| 899 |
-
max_new_tokens: Optional[int] = None,
|
| 900 |
-
temperature: float = 1.0,
|
| 901 |
-
top_k: int = 0,
|
| 902 |
-
top_p: float = 1.0,
|
| 903 |
-
no_repeat_ngram_size: int = 0,
|
| 904 |
-
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
| 905 |
-
length_penalty: float = 1.0,
|
| 906 |
-
num_return_sequences: int = 1,
|
| 907 |
-
do_sample: bool = False,
|
| 908 |
-
early_stopping: bool = False,
|
| 909 |
-
**kwargs,
|
| 910 |
-
):
|
| 911 |
-
"""
|
| 912 |
-
Generate text conditioned on vision and language inputs.
|
| 913 |
-
|
| 914 |
-
Args:
|
| 915 |
-
vision_x (torch.Tensor): Vision input
|
| 916 |
-
shape (B, T_img, F, C, H, W)
|
| 917 |
-
images in the same chunk are collated along T_img, and frames are collated along F
|
| 918 |
-
currently only F=1 is supported (single-frame videos)
|
| 919 |
-
lang_x (torch.Tensor): Language input
|
| 920 |
-
shape (B, T_txt)
|
| 921 |
-
max_length (int, optional): Maximum length of the output. Defaults to None.
|
| 922 |
-
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
| 923 |
-
num_beams (int, optional): Number of beams. Defaults to 1.
|
| 924 |
-
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
| 925 |
-
temperature (float, optional): Temperature. Defaults to 1.0.
|
| 926 |
-
top_k (int, optional): Top k. Defaults to 0.
|
| 927 |
-
top_p (float, optional): Top p. Defaults to 1.0.
|
| 928 |
-
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
| 929 |
-
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
| 930 |
-
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
| 931 |
-
do_sample (bool, optional): Do sample. Defaults to False.
|
| 932 |
-
early_stopping (bool, optional): Early stopping. Defaults to False.
|
| 933 |
-
Returns:
|
| 934 |
-
torch.Tensor: lang_x with generated tokens appended to it
|
| 935 |
-
"""
|
| 936 |
-
if hasattr(self, "_hf_hook"):
|
| 937 |
-
# add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
|
| 938 |
-
hook = AlignDevicesHook(
|
| 939 |
-
execution_device=lang_x.device,
|
| 940 |
-
io_same_device=True,
|
| 941 |
-
place_submodules=False,
|
| 942 |
-
)
|
| 943 |
-
add_hook_to_module(self.lang_encoder, hook)
|
| 944 |
-
if num_beams > 1:
|
| 945 |
-
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
| 946 |
-
self._encode_vision_x(vision_x=vision_x)
|
| 947 |
-
output = self.lang_encoder.generate(
|
| 948 |
-
lang_x,
|
| 949 |
-
attention_mask=attention_mask,
|
| 950 |
-
eos_token_id=self.eoc_token_id,
|
| 951 |
-
num_beams=num_beams,
|
| 952 |
-
max_new_tokens=max_new_tokens,
|
| 953 |
-
temperature=temperature,
|
| 954 |
-
top_k=top_k,
|
| 955 |
-
top_p=top_p,
|
| 956 |
-
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 957 |
-
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 958 |
-
length_penalty=length_penalty,
|
| 959 |
-
num_return_sequences=num_return_sequences,
|
| 960 |
-
do_sample=do_sample,
|
| 961 |
-
early_stopping=early_stopping,
|
| 962 |
-
**kwargs,
|
| 963 |
-
)
|
| 964 |
-
|
| 965 |
-
self.lang_encoder.clear_conditioned_layers()
|
| 966 |
-
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/__init__.py
DELETED
|
File without changes
|
mllm/flamingo/mpt/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (206 Bytes)
|
|
|
mllm/flamingo/mpt/__pycache__/attention.cpython-39.pyc
DELETED
|
Binary file (12.2 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/blocks.cpython-39.pyc
DELETED
|
Binary file (2.81 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/configuration_mpt.cpython-39.pyc
DELETED
|
Binary file (8.76 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/custom_embedding.cpython-39.pyc
DELETED
|
Binary file (797 Bytes)
|
|
|
mllm/flamingo/mpt/__pycache__/flash_attn_triton.cpython-39.pyc
DELETED
|
Binary file (20.9 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/modeling_mpt.cpython-39.pyc
DELETED
|
Binary file (15.3 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/norm.cpython-39.pyc
DELETED
|
Binary file (3.03 kB)
|
|
|
mllm/flamingo/mpt/__pycache__/param_init_fns.cpython-39.pyc
DELETED
|
Binary file (9.14 kB)
|
|
|
mllm/flamingo/mpt/adapt_tokenizer.py
DELETED
|
@@ -1,44 +0,0 @@
|
|
| 1 |
-
from typing import Union
|
| 2 |
-
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
| 3 |
-
|
| 4 |
-
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 5 |
-
NUM_SENTINEL_TOKENS: int = 100
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
|
| 9 |
-
"""Adds sentinel tokens and padding token (if missing).
|
| 10 |
-
|
| 11 |
-
Expands the tokenizer vocabulary to include sentinel tokens
|
| 12 |
-
used in mixture-of-denoiser tasks as well as a padding token.
|
| 13 |
-
|
| 14 |
-
All added tokens are added as special tokens. No tokens are
|
| 15 |
-
added if sentinel tokens and padding token already exist.
|
| 16 |
-
"""
|
| 17 |
-
sentinels_to_add = [f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)]
|
| 18 |
-
tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
|
| 19 |
-
if tokenizer.pad_token is None:
|
| 20 |
-
tokenizer.add_tokens("<pad>", special_tokens=True)
|
| 21 |
-
tokenizer.pad_token = "<pad>"
|
| 22 |
-
assert tokenizer.pad_token_id is not None
|
| 23 |
-
sentinels = "".join([f"<extra_id_{i}>" for i in range(NUM_SENTINEL_TOKENS)])
|
| 24 |
-
_sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
|
| 25 |
-
tokenizer.sentinel_token_ids = _sentinel_token_ids
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
class AutoTokenizerForMOD(AutoTokenizer):
|
| 29 |
-
"""AutoTokenizer + Adaptation for MOD.
|
| 30 |
-
|
| 31 |
-
A simple wrapper around AutoTokenizer to make instantiating
|
| 32 |
-
an MOD-adapted tokenizer a bit easier.
|
| 33 |
-
|
| 34 |
-
MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
|
| 35 |
-
a padding token, and a property to get the token ids of the
|
| 36 |
-
sentinel tokens.
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
@classmethod
|
| 40 |
-
def from_pretrained(cls, *args, **kwargs):
|
| 41 |
-
"""See `AutoTokenizer.from_pretrained` docstring."""
|
| 42 |
-
tokenizer = super().from_pretrained(*args, **kwargs)
|
| 43 |
-
adapt_tokenizer_for_denoising(tokenizer)
|
| 44 |
-
return tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/attention.py
DELETED
|
@@ -1,450 +0,0 @@
|
|
| 1 |
-
"""Attention layers."""
|
| 2 |
-
import math
|
| 3 |
-
import warnings
|
| 4 |
-
from typing import Optional
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
from einops import rearrange
|
| 8 |
-
from packaging import version
|
| 9 |
-
from torch import nn
|
| 10 |
-
from .norm import LPLayerNorm
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
|
| 14 |
-
if original_is_causal and num_query_tokens != num_key_tokens:
|
| 15 |
-
if num_query_tokens != 1:
|
| 16 |
-
raise NotImplementedError("MPT does not support query and key with different number of tokens, unless number of query tokens is 1.")
|
| 17 |
-
else:
|
| 18 |
-
return False
|
| 19 |
-
return original_is_causal
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def scaled_multihead_dot_product_attention(
|
| 23 |
-
query,
|
| 24 |
-
key,
|
| 25 |
-
value,
|
| 26 |
-
n_heads,
|
| 27 |
-
past_key_value=None,
|
| 28 |
-
softmax_scale=None,
|
| 29 |
-
attn_bias=None,
|
| 30 |
-
key_padding_mask=None,
|
| 31 |
-
is_causal=False,
|
| 32 |
-
dropout_p=0.0,
|
| 33 |
-
training=False,
|
| 34 |
-
needs_weights=False,
|
| 35 |
-
multiquery=False,
|
| 36 |
-
):
|
| 37 |
-
q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
|
| 38 |
-
kv_n_heads = 1 if multiquery else n_heads
|
| 39 |
-
k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
|
| 40 |
-
v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
|
| 41 |
-
if past_key_value is not None:
|
| 42 |
-
if len(past_key_value) != 0:
|
| 43 |
-
k = torch.cat([past_key_value[0], k], dim=3)
|
| 44 |
-
v = torch.cat([past_key_value[1], v], dim=2)
|
| 45 |
-
past_key_value = (k, v)
|
| 46 |
-
(b, _, s_q, d) = q.shape
|
| 47 |
-
s_k = k.size(-1)
|
| 48 |
-
if softmax_scale is None:
|
| 49 |
-
softmax_scale = 1 / math.sqrt(d)
|
| 50 |
-
attn_weight = q.matmul(k) * softmax_scale
|
| 51 |
-
if attn_bias is not None:
|
| 52 |
-
_s_q = max(0, attn_bias.size(2) - s_q)
|
| 53 |
-
_s_k = max(0, attn_bias.size(3) - s_k)
|
| 54 |
-
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 55 |
-
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
|
| 56 |
-
raise RuntimeError(f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.")
|
| 57 |
-
attn_weight = attn_weight + attn_bias
|
| 58 |
-
min_val = torch.finfo(q.dtype).min
|
| 59 |
-
if key_padding_mask is not None:
|
| 60 |
-
if attn_bias is not None:
|
| 61 |
-
warnings.warn(
|
| 62 |
-
"Propogating key_padding_mask to the attention module "
|
| 63 |
-
+ "and applying it within the attention module can cause "
|
| 64 |
-
+ "unneccessary computation/memory usage. Consider integrating "
|
| 65 |
-
+ "into attn_bias once and passing that to each attention "
|
| 66 |
-
+ "module instead."
|
| 67 |
-
)
|
| 68 |
-
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
|
| 69 |
-
if is_causal and (not q.size(2) == 1):
|
| 70 |
-
s = max(s_q, s_k)
|
| 71 |
-
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
|
| 72 |
-
causal_mask = causal_mask.tril()
|
| 73 |
-
causal_mask = causal_mask.to(torch.bool)
|
| 74 |
-
causal_mask = ~causal_mask
|
| 75 |
-
causal_mask = causal_mask[-s_q:, -s_k:]
|
| 76 |
-
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
|
| 77 |
-
attn_weight = torch.softmax(attn_weight, dim=-1)
|
| 78 |
-
if dropout_p:
|
| 79 |
-
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
|
| 80 |
-
out = attn_weight.to(v.dtype).matmul(v)
|
| 81 |
-
out = rearrange(out, "b h s d -> b s (h d)")
|
| 82 |
-
if needs_weights:
|
| 83 |
-
return (out, attn_weight, past_key_value)
|
| 84 |
-
return (out, None, past_key_value)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
| 88 |
-
for tensor in tensors:
|
| 89 |
-
if tensor.dtype not in valid_dtypes:
|
| 90 |
-
raise TypeError(f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.")
|
| 91 |
-
if not tensor.is_cuda:
|
| 92 |
-
raise TypeError(f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).")
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
def flash_attn_fn(
|
| 96 |
-
query,
|
| 97 |
-
key,
|
| 98 |
-
value,
|
| 99 |
-
n_heads,
|
| 100 |
-
past_key_value=None,
|
| 101 |
-
softmax_scale=None,
|
| 102 |
-
attn_bias=None,
|
| 103 |
-
key_padding_mask=None,
|
| 104 |
-
is_causal=False,
|
| 105 |
-
dropout_p=0.0,
|
| 106 |
-
training=False,
|
| 107 |
-
needs_weights=False,
|
| 108 |
-
multiquery=False,
|
| 109 |
-
):
|
| 110 |
-
try:
|
| 111 |
-
from flash_attn import bert_padding, flash_attn_interface
|
| 112 |
-
except:
|
| 113 |
-
raise RuntimeError("Please install flash-attn==1.0.3.post0")
|
| 114 |
-
check_valid_inputs(query, key, value)
|
| 115 |
-
if past_key_value is not None:
|
| 116 |
-
if len(past_key_value) != 0:
|
| 117 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
| 118 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
| 119 |
-
past_key_value = (key, value)
|
| 120 |
-
if attn_bias is not None:
|
| 121 |
-
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
| 122 |
-
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
| 123 |
-
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 124 |
-
if attn_bias is not None:
|
| 125 |
-
raise NotImplementedError(f"attn_bias not implemented for flash attn.")
|
| 126 |
-
(batch_size, seqlen) = query.shape[:2]
|
| 127 |
-
if key_padding_mask is None:
|
| 128 |
-
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
|
| 129 |
-
query_padding_mask = key_padding_mask[:, -query.size(1) :]
|
| 130 |
-
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
|
| 131 |
-
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
|
| 132 |
-
(key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
|
| 133 |
-
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
|
| 134 |
-
(value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
|
| 135 |
-
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads)
|
| 136 |
-
if multiquery:
|
| 137 |
-
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
|
| 138 |
-
value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
|
| 139 |
-
dropout_p = dropout_p if training else 0.0
|
| 140 |
-
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 141 |
-
output_unpad = flash_attn_interface.flash_attn_unpadded_func(
|
| 142 |
-
query_unpad,
|
| 143 |
-
key_unpad,
|
| 144 |
-
value_unpad,
|
| 145 |
-
cu_seqlens_q,
|
| 146 |
-
cu_seqlens_k,
|
| 147 |
-
max_seqlen_q,
|
| 148 |
-
max_seqlen_k,
|
| 149 |
-
dropout_p,
|
| 150 |
-
softmax_scale=softmax_scale,
|
| 151 |
-
causal=reset_is_causal,
|
| 152 |
-
return_attn_probs=needs_weights,
|
| 153 |
-
)
|
| 154 |
-
output = bert_padding.pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen)
|
| 155 |
-
return (output, None, past_key_value)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
def triton_flash_attn_fn(
|
| 159 |
-
query,
|
| 160 |
-
key,
|
| 161 |
-
value,
|
| 162 |
-
n_heads,
|
| 163 |
-
past_key_value=None,
|
| 164 |
-
softmax_scale=None,
|
| 165 |
-
attn_bias=None,
|
| 166 |
-
key_padding_mask=None,
|
| 167 |
-
is_causal=False,
|
| 168 |
-
dropout_p=0.0,
|
| 169 |
-
training=False,
|
| 170 |
-
needs_weights=False,
|
| 171 |
-
multiquery=False,
|
| 172 |
-
):
|
| 173 |
-
try:
|
| 174 |
-
from .flash_attn_triton import flash_attn_func
|
| 175 |
-
except:
|
| 176 |
-
_installed = False
|
| 177 |
-
if version.parse(torch.__version__) < version.parse("2.0.0"):
|
| 178 |
-
_installed = True
|
| 179 |
-
try:
|
| 180 |
-
from flash_attn.flash_attn_triton import flash_attn_func
|
| 181 |
-
except:
|
| 182 |
-
_installed = False
|
| 183 |
-
if not _installed:
|
| 184 |
-
raise RuntimeError(
|
| 185 |
-
"Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
|
| 186 |
-
)
|
| 187 |
-
check_valid_inputs(query, key, value)
|
| 188 |
-
if past_key_value is not None:
|
| 189 |
-
if len(past_key_value) != 0:
|
| 190 |
-
key = torch.cat([past_key_value[0], key], dim=1)
|
| 191 |
-
value = torch.cat([past_key_value[1], value], dim=1)
|
| 192 |
-
past_key_value = (key, value)
|
| 193 |
-
if attn_bias is not None:
|
| 194 |
-
_s_q = max(0, attn_bias.size(2) - query.size(1))
|
| 195 |
-
_s_k = max(0, attn_bias.size(3) - key.size(1))
|
| 196 |
-
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
|
| 197 |
-
if dropout_p:
|
| 198 |
-
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
|
| 199 |
-
if needs_weights:
|
| 200 |
-
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
|
| 201 |
-
if key_padding_mask is not None:
|
| 202 |
-
warnings.warn(
|
| 203 |
-
"Propagating key_padding_mask to the attention module "
|
| 204 |
-
+ "and applying it within the attention module can cause "
|
| 205 |
-
+ "unnecessary computation/memory usage. Consider integrating "
|
| 206 |
-
+ "into attn_bias once and passing that to each attention "
|
| 207 |
-
+ "module instead."
|
| 208 |
-
)
|
| 209 |
-
(b_size, s_k) = key_padding_mask.shape[:2]
|
| 210 |
-
if attn_bias is None:
|
| 211 |
-
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
|
| 212 |
-
attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
|
| 213 |
-
query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
|
| 214 |
-
key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
| 215 |
-
value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
|
| 216 |
-
if multiquery:
|
| 217 |
-
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
|
| 218 |
-
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
|
| 219 |
-
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
|
| 220 |
-
attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
|
| 221 |
-
output = attn_output.view(*attn_output.shape[:2], -1)
|
| 222 |
-
return (output, None, past_key_value)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
class MultiheadAttention(nn.Module):
|
| 226 |
-
"""Multi-head self attention.
|
| 227 |
-
|
| 228 |
-
Using torch or triton attention implemetation enables user to also use
|
| 229 |
-
additive bias.
|
| 230 |
-
"""
|
| 231 |
-
|
| 232 |
-
def __init__(
|
| 233 |
-
self,
|
| 234 |
-
d_model: int,
|
| 235 |
-
n_heads: int,
|
| 236 |
-
attn_impl: str = "triton",
|
| 237 |
-
clip_qkv: Optional[float] = None,
|
| 238 |
-
qk_ln: bool = False,
|
| 239 |
-
softmax_scale: Optional[float] = None,
|
| 240 |
-
attn_pdrop: float = 0.0,
|
| 241 |
-
low_precision_layernorm: bool = False,
|
| 242 |
-
verbose: int = 0,
|
| 243 |
-
device: Optional[str] = None,
|
| 244 |
-
):
|
| 245 |
-
super().__init__()
|
| 246 |
-
self.attn_impl = attn_impl
|
| 247 |
-
self.clip_qkv = clip_qkv
|
| 248 |
-
self.qk_ln = qk_ln
|
| 249 |
-
self.d_model = d_model
|
| 250 |
-
self.n_heads = n_heads
|
| 251 |
-
self.softmax_scale = softmax_scale
|
| 252 |
-
if self.softmax_scale is None:
|
| 253 |
-
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
| 254 |
-
self.attn_dropout_p = attn_pdrop
|
| 255 |
-
self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
|
| 256 |
-
fuse_splits = (d_model, 2 * d_model)
|
| 257 |
-
self.Wqkv._fused = (0, fuse_splits)
|
| 258 |
-
if self.qk_ln:
|
| 259 |
-
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
| 260 |
-
self.q_ln = layernorm_class(self.d_model, device=device)
|
| 261 |
-
self.k_ln = layernorm_class(self.d_model, device=device)
|
| 262 |
-
if self.attn_impl == "flash":
|
| 263 |
-
self.attn_fn = flash_attn_fn
|
| 264 |
-
elif self.attn_impl == "triton":
|
| 265 |
-
self.attn_fn = triton_flash_attn_fn
|
| 266 |
-
if verbose:
|
| 267 |
-
warnings.warn(
|
| 268 |
-
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
| 269 |
-
+ "it uses more memory. When training larger models this can trigger "
|
| 270 |
-
+ "alloc retries which hurts performance. If encountered, we recommend "
|
| 271 |
-
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
| 272 |
-
)
|
| 273 |
-
elif self.attn_impl == "torch":
|
| 274 |
-
self.attn_fn = scaled_multihead_dot_product_attention
|
| 275 |
-
if torch.cuda.is_available() and verbose:
|
| 276 |
-
warnings.warn(
|
| 277 |
-
"Using `attn_impl: torch`. If your model does not use `alibi` or "
|
| 278 |
-
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
|
| 279 |
-
+ "we recommend using `attn_impl: triton`."
|
| 280 |
-
)
|
| 281 |
-
else:
|
| 282 |
-
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 283 |
-
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
| 284 |
-
self.out_proj._is_residual = True
|
| 285 |
-
|
| 286 |
-
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
| 287 |
-
qkv = self.Wqkv(x)
|
| 288 |
-
if self.clip_qkv:
|
| 289 |
-
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 290 |
-
(query, key, value) = qkv.chunk(3, dim=2)
|
| 291 |
-
key_padding_mask = attention_mask
|
| 292 |
-
if self.qk_ln:
|
| 293 |
-
dtype = query.dtype
|
| 294 |
-
query = self.q_ln(query).to(dtype)
|
| 295 |
-
key = self.k_ln(key).to(dtype)
|
| 296 |
-
(context, attn_weights, past_key_value) = self.attn_fn(
|
| 297 |
-
query,
|
| 298 |
-
key,
|
| 299 |
-
value,
|
| 300 |
-
self.n_heads,
|
| 301 |
-
past_key_value=past_key_value,
|
| 302 |
-
softmax_scale=self.softmax_scale,
|
| 303 |
-
attn_bias=attn_bias,
|
| 304 |
-
key_padding_mask=key_padding_mask,
|
| 305 |
-
is_causal=is_causal,
|
| 306 |
-
dropout_p=self.attn_dropout_p,
|
| 307 |
-
training=self.training,
|
| 308 |
-
needs_weights=needs_weights,
|
| 309 |
-
)
|
| 310 |
-
return (self.out_proj(context), attn_weights, past_key_value)
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
class MultiQueryAttention(nn.Module):
|
| 314 |
-
"""Multi-Query self attention.
|
| 315 |
-
|
| 316 |
-
Using torch or triton attention implemetation enables user to also use
|
| 317 |
-
additive bias.
|
| 318 |
-
"""
|
| 319 |
-
|
| 320 |
-
def __init__(
|
| 321 |
-
self,
|
| 322 |
-
d_model: int,
|
| 323 |
-
n_heads: int,
|
| 324 |
-
attn_impl: str = "triton",
|
| 325 |
-
clip_qkv: Optional[float] = None,
|
| 326 |
-
qk_ln: bool = False,
|
| 327 |
-
softmax_scale: Optional[float] = None,
|
| 328 |
-
attn_pdrop: float = 0.0,
|
| 329 |
-
low_precision_layernorm: bool = False,
|
| 330 |
-
verbose: int = 0,
|
| 331 |
-
device: Optional[str] = None,
|
| 332 |
-
):
|
| 333 |
-
super().__init__()
|
| 334 |
-
self.attn_impl = attn_impl
|
| 335 |
-
self.clip_qkv = clip_qkv
|
| 336 |
-
self.qk_ln = qk_ln
|
| 337 |
-
self.d_model = d_model
|
| 338 |
-
self.n_heads = n_heads
|
| 339 |
-
self.head_dim = d_model // n_heads
|
| 340 |
-
self.softmax_scale = softmax_scale
|
| 341 |
-
if self.softmax_scale is None:
|
| 342 |
-
self.softmax_scale = 1 / math.sqrt(self.head_dim)
|
| 343 |
-
self.attn_dropout_p = attn_pdrop
|
| 344 |
-
self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
|
| 345 |
-
fuse_splits = (d_model, d_model + self.head_dim)
|
| 346 |
-
self.Wqkv._fused = (0, fuse_splits)
|
| 347 |
-
if self.qk_ln:
|
| 348 |
-
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
|
| 349 |
-
self.q_ln = layernorm_class(d_model, device=device)
|
| 350 |
-
self.k_ln = layernorm_class(self.head_dim, device=device)
|
| 351 |
-
if self.attn_impl == "flash":
|
| 352 |
-
self.attn_fn = flash_attn_fn
|
| 353 |
-
elif self.attn_impl == "triton":
|
| 354 |
-
self.attn_fn = triton_flash_attn_fn
|
| 355 |
-
if verbose:
|
| 356 |
-
warnings.warn(
|
| 357 |
-
"While `attn_impl: triton` can be faster than `attn_impl: flash` "
|
| 358 |
-
+ "it uses more memory. When training larger models this can trigger "
|
| 359 |
-
+ "alloc retries which hurts performance. If encountered, we recommend "
|
| 360 |
-
+ "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
|
| 361 |
-
)
|
| 362 |
-
elif self.attn_impl == "torch":
|
| 363 |
-
self.attn_fn = scaled_multihead_dot_product_attention
|
| 364 |
-
if torch.cuda.is_available() and verbose:
|
| 365 |
-
warnings.warn(
|
| 366 |
-
"Using `attn_impl: torch`. If your model does not use `alibi` or "
|
| 367 |
-
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
|
| 368 |
-
+ "we recommend using `attn_impl: triton`."
|
| 369 |
-
)
|
| 370 |
-
else:
|
| 371 |
-
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 372 |
-
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
|
| 373 |
-
self.out_proj._is_residual = True
|
| 374 |
-
|
| 375 |
-
def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
|
| 376 |
-
qkv = self.Wqkv(x)
|
| 377 |
-
if self.clip_qkv:
|
| 378 |
-
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
| 379 |
-
(query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
|
| 380 |
-
key_padding_mask = attention_mask
|
| 381 |
-
if self.qk_ln:
|
| 382 |
-
dtype = query.dtype
|
| 383 |
-
query = self.q_ln(query).to(dtype)
|
| 384 |
-
key = self.k_ln(key).to(dtype)
|
| 385 |
-
(context, attn_weights, past_key_value) = self.attn_fn(
|
| 386 |
-
query,
|
| 387 |
-
key,
|
| 388 |
-
value,
|
| 389 |
-
self.n_heads,
|
| 390 |
-
past_key_value=past_key_value,
|
| 391 |
-
softmax_scale=self.softmax_scale,
|
| 392 |
-
attn_bias=attn_bias,
|
| 393 |
-
key_padding_mask=key_padding_mask,
|
| 394 |
-
is_causal=is_causal,
|
| 395 |
-
dropout_p=self.attn_dropout_p,
|
| 396 |
-
training=self.training,
|
| 397 |
-
needs_weights=needs_weights,
|
| 398 |
-
multiquery=True,
|
| 399 |
-
)
|
| 400 |
-
return (self.out_proj(context), attn_weights, past_key_value)
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
|
| 404 |
-
if attn_impl == "flash":
|
| 405 |
-
return None
|
| 406 |
-
elif attn_impl in ["torch", "triton"]:
|
| 407 |
-
if alibi:
|
| 408 |
-
if (prefix_lm or not causal) or use_sequence_id:
|
| 409 |
-
return (1, n_heads, seq_len, seq_len)
|
| 410 |
-
return (1, n_heads, 1, seq_len)
|
| 411 |
-
elif prefix_lm or use_sequence_id:
|
| 412 |
-
return (1, 1, seq_len, seq_len)
|
| 413 |
-
return None
|
| 414 |
-
else:
|
| 415 |
-
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
|
| 419 |
-
if attn_impl == "flash":
|
| 420 |
-
return None
|
| 421 |
-
elif attn_impl in ["torch", "triton"]:
|
| 422 |
-
if alibi:
|
| 423 |
-
(device, dtype) = (attn_bias.device, attn_bias.dtype)
|
| 424 |
-
attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
|
| 425 |
-
return attn_bias
|
| 426 |
-
else:
|
| 427 |
-
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
|
| 431 |
-
_n_heads = 2 ** math.ceil(math.log2(n_heads))
|
| 432 |
-
m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
|
| 433 |
-
m = m.mul(alibi_bias_max / _n_heads)
|
| 434 |
-
slopes = 1.0 / torch.pow(2, m)
|
| 435 |
-
if _n_heads != n_heads:
|
| 436 |
-
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
|
| 437 |
-
return slopes.view(1, n_heads, 1, 1)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
|
| 441 |
-
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
|
| 442 |
-
if full:
|
| 443 |
-
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
|
| 444 |
-
alibi_bias = alibi_bias.abs().mul(-1)
|
| 445 |
-
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
|
| 446 |
-
alibi_bias = alibi_bias * slopes
|
| 447 |
-
return alibi_bias.to(dtype=dtype)
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention, "multiquery_attention": MultiQueryAttention}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/blocks.py
DELETED
|
@@ -1,82 +0,0 @@
|
|
| 1 |
-
"""GPT Blocks used for the GPT Model."""
|
| 2 |
-
from typing import Dict, Optional, Tuple
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
from .attention import ATTN_CLASS_REGISTRY
|
| 6 |
-
from .norm import NORM_CLASS_REGISTRY
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class MPTMLP(nn.Module):
|
| 10 |
-
def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str] = None):
|
| 11 |
-
super().__init__()
|
| 12 |
-
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
|
| 13 |
-
## yh: hard code
|
| 14 |
-
# self.act = nn.GELU(approximate='none')
|
| 15 |
-
self.act = nn.GELU()
|
| 16 |
-
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
|
| 17 |
-
self.down_proj._is_residual = True
|
| 18 |
-
|
| 19 |
-
def forward(self, x):
|
| 20 |
-
return self.down_proj(self.act(self.up_proj(x)))
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class MPTBlock(nn.Module):
|
| 24 |
-
def __init__(
|
| 25 |
-
self,
|
| 26 |
-
d_model: int,
|
| 27 |
-
n_heads: int,
|
| 28 |
-
expansion_ratio: int,
|
| 29 |
-
attn_config: Dict = {
|
| 30 |
-
"attn_type": "multihead_attention",
|
| 31 |
-
"attn_pdrop": 0.0,
|
| 32 |
-
"attn_impl": "triton",
|
| 33 |
-
"qk_ln": False,
|
| 34 |
-
"clip_qkv": None,
|
| 35 |
-
"softmax_scale": None,
|
| 36 |
-
"prefix_lm": False,
|
| 37 |
-
"attn_uses_sequence_id": False,
|
| 38 |
-
"alibi": False,
|
| 39 |
-
"alibi_bias_max": 8,
|
| 40 |
-
},
|
| 41 |
-
resid_pdrop: float = 0.0,
|
| 42 |
-
norm_type: str = "low_precision_layernorm",
|
| 43 |
-
verbose: int = 0,
|
| 44 |
-
device: Optional[str] = None,
|
| 45 |
-
**kwargs
|
| 46 |
-
):
|
| 47 |
-
del kwargs
|
| 48 |
-
super().__init__()
|
| 49 |
-
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
| 50 |
-
attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
|
| 51 |
-
self.norm_1 = norm_class(d_model, device=device)
|
| 52 |
-
self.attn = attn_class(
|
| 53 |
-
attn_impl=attn_config["attn_impl"],
|
| 54 |
-
clip_qkv=attn_config["clip_qkv"],
|
| 55 |
-
qk_ln=attn_config["qk_ln"],
|
| 56 |
-
softmax_scale=attn_config["softmax_scale"],
|
| 57 |
-
attn_pdrop=attn_config["attn_pdrop"],
|
| 58 |
-
d_model=d_model,
|
| 59 |
-
n_heads=n_heads,
|
| 60 |
-
verbose=verbose,
|
| 61 |
-
device=device,
|
| 62 |
-
)
|
| 63 |
-
self.norm_2 = norm_class(d_model, device=device)
|
| 64 |
-
self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
|
| 65 |
-
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
| 66 |
-
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
| 67 |
-
|
| 68 |
-
def forward(
|
| 69 |
-
self,
|
| 70 |
-
x: torch.Tensor,
|
| 71 |
-
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 72 |
-
attn_bias: Optional[torch.Tensor] = None,
|
| 73 |
-
attention_mask: Optional[torch.ByteTensor] = None,
|
| 74 |
-
is_causal: bool = True,
|
| 75 |
-
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
| 76 |
-
a = self.norm_1(x)
|
| 77 |
-
(b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
|
| 78 |
-
x = x + self.resid_attn_dropout(b)
|
| 79 |
-
m = self.norm_2(x)
|
| 80 |
-
n = self.ffn(m)
|
| 81 |
-
x = x + self.resid_ffn_dropout(n)
|
| 82 |
-
return (x, attn_weights, past_key_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/configuration_mpt.py
DELETED
|
@@ -1,161 +0,0 @@
|
|
| 1 |
-
"""A HuggingFace-style model configuration."""
|
| 2 |
-
from typing import Dict, Optional, Union
|
| 3 |
-
from transformers import PretrainedConfig
|
| 4 |
-
|
| 5 |
-
attn_config_defaults: Dict = {
|
| 6 |
-
"attn_type": "multihead_attention",
|
| 7 |
-
"attn_pdrop": 0.0,
|
| 8 |
-
"attn_impl": "triton",
|
| 9 |
-
"qk_ln": False,
|
| 10 |
-
"clip_qkv": None,
|
| 11 |
-
"softmax_scale": None,
|
| 12 |
-
"prefix_lm": False,
|
| 13 |
-
"attn_uses_sequence_id": False,
|
| 14 |
-
"alibi": False,
|
| 15 |
-
"alibi_bias_max": 8,
|
| 16 |
-
}
|
| 17 |
-
init_config_defaults: Dict = {
|
| 18 |
-
"name": "kaiming_normal_",
|
| 19 |
-
"fan_mode": "fan_in",
|
| 20 |
-
"init_nonlinearity": "relu",
|
| 21 |
-
"init_div_is_residual": True,
|
| 22 |
-
"emb_init_std": None,
|
| 23 |
-
"emb_init_uniform_lim": None,
|
| 24 |
-
"init_std": None,
|
| 25 |
-
"init_gain": 0.0,
|
| 26 |
-
}
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
class MPTConfig(PretrainedConfig):
|
| 30 |
-
model_type = "mpt"
|
| 31 |
-
|
| 32 |
-
def __init__(
|
| 33 |
-
self,
|
| 34 |
-
d_model: int = 2048,
|
| 35 |
-
n_heads: int = 16,
|
| 36 |
-
n_layers: int = 24,
|
| 37 |
-
expansion_ratio: int = 4,
|
| 38 |
-
max_seq_len: int = 2048,
|
| 39 |
-
vocab_size: int = 50368,
|
| 40 |
-
resid_pdrop: float = 0.0,
|
| 41 |
-
emb_pdrop: float = 0.0,
|
| 42 |
-
learned_pos_emb: bool = True,
|
| 43 |
-
attn_config: Dict = attn_config_defaults,
|
| 44 |
-
init_device: str = "cpu",
|
| 45 |
-
logit_scale: Optional[Union[float, str]] = None,
|
| 46 |
-
no_bias: bool = False,
|
| 47 |
-
verbose: int = 0,
|
| 48 |
-
embedding_fraction: float = 1.0,
|
| 49 |
-
norm_type: str = "low_precision_layernorm",
|
| 50 |
-
use_cache: bool = False,
|
| 51 |
-
init_config: Dict = init_config_defaults,
|
| 52 |
-
**kwargs,
|
| 53 |
-
):
|
| 54 |
-
"""The MPT configuration class.
|
| 55 |
-
|
| 56 |
-
Args:
|
| 57 |
-
d_model (int): The size of the embedding dimension of the model.
|
| 58 |
-
n_heads (int): The number of attention heads.
|
| 59 |
-
n_layers (int): The number of layers in the model.
|
| 60 |
-
expansion_ratio (int): The ratio of the up/down scale in the MLP.
|
| 61 |
-
max_seq_len (int): The maximum sequence length of the model.
|
| 62 |
-
vocab_size (int): The size of the vocabulary.
|
| 63 |
-
resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
|
| 64 |
-
emb_pdrop (float): The dropout probability for the embedding layer.
|
| 65 |
-
learned_pos_emb (bool): Whether to use learned positional embeddings
|
| 66 |
-
attn_config (Dict): A dictionary used to configure the model's attention module:
|
| 67 |
-
attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
|
| 68 |
-
attn_pdrop (float): The dropout probability for the attention layers.
|
| 69 |
-
attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
|
| 70 |
-
qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
|
| 71 |
-
clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
|
| 72 |
-
this value.
|
| 73 |
-
softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
|
| 74 |
-
use the default scale of ``1/sqrt(d_keys)``.
|
| 75 |
-
prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
|
| 76 |
-
extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
|
| 77 |
-
can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
|
| 78 |
-
attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
|
| 79 |
-
When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
|
| 80 |
-
which sub-sequence each token belongs to.
|
| 81 |
-
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
|
| 82 |
-
alibi (bool): Whether to use the alibi bias instead of position embeddings.
|
| 83 |
-
alibi_bias_max (int): The maximum value of the alibi bias.
|
| 84 |
-
init_device (str): The device to use for parameter initialization.
|
| 85 |
-
logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
|
| 86 |
-
no_bias (bool): Whether to use bias in all layers.
|
| 87 |
-
verbose (int): The verbosity level. 0 is silent.
|
| 88 |
-
embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
|
| 89 |
-
norm_type (str): choose type of norm to use
|
| 90 |
-
multiquery_attention (bool): Whether to use multiquery attention implementation.
|
| 91 |
-
use_cache (bool): Whether or not the model should return the last key/values attentions
|
| 92 |
-
init_config (Dict): A dictionary used to configure the model initialization:
|
| 93 |
-
init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
|
| 94 |
-
'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
|
| 95 |
-
'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
|
| 96 |
-
init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
|
| 97 |
-
emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
|
| 98 |
-
emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
|
| 99 |
-
used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
|
| 100 |
-
init_std (float): The standard deviation of the normal distribution used to initialize the model,
|
| 101 |
-
if using the baseline_ parameter initialization scheme.
|
| 102 |
-
init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
|
| 103 |
-
fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
|
| 104 |
-
init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
|
| 105 |
-
---
|
| 106 |
-
See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
|
| 107 |
-
"""
|
| 108 |
-
self.d_model = d_model
|
| 109 |
-
self.n_heads = n_heads
|
| 110 |
-
self.n_layers = n_layers
|
| 111 |
-
self.expansion_ratio = expansion_ratio
|
| 112 |
-
self.max_seq_len = max_seq_len
|
| 113 |
-
self.vocab_size = vocab_size
|
| 114 |
-
self.resid_pdrop = resid_pdrop
|
| 115 |
-
self.emb_pdrop = emb_pdrop
|
| 116 |
-
self.learned_pos_emb = learned_pos_emb
|
| 117 |
-
self.attn_config = attn_config
|
| 118 |
-
self.init_device = init_device
|
| 119 |
-
self.logit_scale = logit_scale
|
| 120 |
-
self.no_bias = no_bias
|
| 121 |
-
self.verbose = verbose
|
| 122 |
-
self.embedding_fraction = embedding_fraction
|
| 123 |
-
self.norm_type = norm_type
|
| 124 |
-
self.use_cache = use_cache
|
| 125 |
-
self.init_config = init_config
|
| 126 |
-
if "name" in kwargs:
|
| 127 |
-
del kwargs["name"]
|
| 128 |
-
if "loss_fn" in kwargs:
|
| 129 |
-
del kwargs["loss_fn"]
|
| 130 |
-
super().__init__(**kwargs)
|
| 131 |
-
self._validate_config()
|
| 132 |
-
|
| 133 |
-
def _set_config_defaults(self, config, config_defaults):
|
| 134 |
-
for k, v in config_defaults.items():
|
| 135 |
-
if k not in config:
|
| 136 |
-
config[k] = v
|
| 137 |
-
return config
|
| 138 |
-
|
| 139 |
-
def _validate_config(self):
|
| 140 |
-
self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
|
| 141 |
-
self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
|
| 142 |
-
if self.d_model % self.n_heads != 0:
|
| 143 |
-
raise ValueError("d_model must be divisible by n_heads")
|
| 144 |
-
if any((prob < 0 or prob > 1 for prob in [self.attn_config["attn_pdrop"], self.resid_pdrop, self.emb_pdrop])):
|
| 145 |
-
raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
|
| 146 |
-
if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
|
| 147 |
-
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
|
| 148 |
-
if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
|
| 149 |
-
raise NotImplementedError("prefix_lm only implemented with torch and triton attention.")
|
| 150 |
-
if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
|
| 151 |
-
raise NotImplementedError("alibi only implemented with torch and triton attention.")
|
| 152 |
-
if self.attn_config["attn_uses_sequence_id"] and self.attn_config["attn_impl"] not in ["torch", "triton"]:
|
| 153 |
-
raise NotImplementedError("attn_uses_sequence_id only implemented with torch and triton attention.")
|
| 154 |
-
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
|
| 155 |
-
raise ValueError("model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!")
|
| 156 |
-
if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
|
| 157 |
-
raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
| 158 |
-
if self.init_config.get("name", None) is None:
|
| 159 |
-
raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
|
| 160 |
-
if not self.learned_pos_emb and (not self.attn_config["alibi"]):
|
| 161 |
-
raise ValueError(f"Positional information must be provided to the model using either learned_pos_emb or alibi.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/custom_embedding.py
DELETED
|
@@ -1,11 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from torch import Tensor
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class SharedEmbedding(nn.Embedding):
|
| 8 |
-
def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
|
| 9 |
-
if unembed:
|
| 10 |
-
return F.linear(input, self.weight)
|
| 11 |
-
return super().forward(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/flash_attn_triton.py
DELETED
|
@@ -1,841 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copied from https://github.com/HazyResearch/flash-attention/blob/eff9fe6b8076df59d64d7a3f464696738a3c7c24/flash_attn/flash_attn_triton.py
|
| 3 |
-
update imports to use 'triton_pre_mlir'
|
| 4 |
-
|
| 5 |
-
*Experimental* implementation of FlashAttention in Triton.
|
| 6 |
-
Tested with triton==2.0.0.dev20221202.
|
| 7 |
-
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
|
| 8 |
-
other than 64:
|
| 9 |
-
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
|
| 10 |
-
We'll update this implementation with the new Triton backend once this is fixed.
|
| 11 |
-
|
| 12 |
-
We use the FlashAttention implementation from Phil Tillet a starting point.
|
| 13 |
-
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
| 14 |
-
|
| 15 |
-
Changes:
|
| 16 |
-
- Implement both causal and non-causal attention.
|
| 17 |
-
- Implement both self-attention and cross-attention.
|
| 18 |
-
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
|
| 19 |
-
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
|
| 20 |
-
- Support attention bias.
|
| 21 |
-
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
|
| 22 |
-
- Make the backward for d=128 much faster by reducing register spilling.
|
| 23 |
-
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
| 24 |
-
small batch size * nheads.
|
| 25 |
-
|
| 26 |
-
Caution:
|
| 27 |
-
- This is an *experimental* implementation. The forward pass should be quite robust but
|
| 28 |
-
I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
|
| 29 |
-
- This implementation has only been tested on A100.
|
| 30 |
-
- If you plan to use headdim other than 64 and 128, you should test for race conditions
|
| 31 |
-
(due to the Triton compiler), as done in tests/test_flash_attn.py
|
| 32 |
-
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
|
| 33 |
-
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
|
| 34 |
-
that there are none left for other head dimensions.
|
| 35 |
-
|
| 36 |
-
Differences between this Triton version and the CUDA version:
|
| 37 |
-
- Triton version doesn't support dropout.
|
| 38 |
-
- Triton forward is generally faster than CUDA forward, while Triton backward is
|
| 39 |
-
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
|
| 40 |
-
than CUDA forward + backward.
|
| 41 |
-
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
| 42 |
-
- Triton version supports attention bias, while CUDA version doesn't.
|
| 43 |
-
"""
|
| 44 |
-
import math
|
| 45 |
-
import torch
|
| 46 |
-
import triton_pre_mlir as triton
|
| 47 |
-
import triton_pre_mlir.language as tl
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
@triton.heuristics(
|
| 51 |
-
{
|
| 52 |
-
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 53 |
-
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 54 |
-
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 55 |
-
}
|
| 56 |
-
)
|
| 57 |
-
@triton.jit
|
| 58 |
-
def _fwd_kernel(
|
| 59 |
-
Q,
|
| 60 |
-
K,
|
| 61 |
-
V,
|
| 62 |
-
Bias,
|
| 63 |
-
Out,
|
| 64 |
-
Lse,
|
| 65 |
-
TMP,
|
| 66 |
-
softmax_scale,
|
| 67 |
-
stride_qb,
|
| 68 |
-
stride_qh,
|
| 69 |
-
stride_qm,
|
| 70 |
-
stride_kb,
|
| 71 |
-
stride_kh,
|
| 72 |
-
stride_kn,
|
| 73 |
-
stride_vb,
|
| 74 |
-
stride_vh,
|
| 75 |
-
stride_vn,
|
| 76 |
-
stride_bb,
|
| 77 |
-
stride_bh,
|
| 78 |
-
stride_bm,
|
| 79 |
-
stride_ob,
|
| 80 |
-
stride_oh,
|
| 81 |
-
stride_om,
|
| 82 |
-
nheads,
|
| 83 |
-
seqlen_q,
|
| 84 |
-
seqlen_k,
|
| 85 |
-
seqlen_q_rounded,
|
| 86 |
-
headdim,
|
| 87 |
-
CACHE_KEY_SEQLEN_Q,
|
| 88 |
-
CACHE_KEY_SEQLEN_K,
|
| 89 |
-
BIAS_TYPE: tl.constexpr,
|
| 90 |
-
IS_CAUSAL: tl.constexpr,
|
| 91 |
-
BLOCK_HEADDIM: tl.constexpr,
|
| 92 |
-
EVEN_M: tl.constexpr,
|
| 93 |
-
EVEN_N: tl.constexpr,
|
| 94 |
-
EVEN_HEADDIM: tl.constexpr,
|
| 95 |
-
BLOCK_M: tl.constexpr,
|
| 96 |
-
BLOCK_N: tl.constexpr,
|
| 97 |
-
):
|
| 98 |
-
start_m = tl.program_id(0)
|
| 99 |
-
off_hb = tl.program_id(1)
|
| 100 |
-
off_b = off_hb // nheads
|
| 101 |
-
off_h = off_hb % nheads
|
| 102 |
-
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 103 |
-
offs_n = tl.arange(0, BLOCK_N)
|
| 104 |
-
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 105 |
-
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
|
| 106 |
-
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 107 |
-
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 108 |
-
if BIAS_TYPE == "vector":
|
| 109 |
-
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
|
| 110 |
-
elif BIAS_TYPE == "matrix":
|
| 111 |
-
b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
|
| 112 |
-
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
| 113 |
-
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 114 |
-
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
| 115 |
-
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
| 116 |
-
if EVEN_M & EVEN_N:
|
| 117 |
-
if EVEN_HEADDIM:
|
| 118 |
-
q = tl.load(q_ptrs)
|
| 119 |
-
else:
|
| 120 |
-
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 121 |
-
elif EVEN_HEADDIM:
|
| 122 |
-
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
| 123 |
-
else:
|
| 124 |
-
q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 125 |
-
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
| 126 |
-
for start_n in range(0, end_n, BLOCK_N):
|
| 127 |
-
start_n = tl.multiple_of(start_n, BLOCK_N)
|
| 128 |
-
if EVEN_N & EVEN_M:
|
| 129 |
-
if EVEN_HEADDIM:
|
| 130 |
-
k = tl.load(k_ptrs + start_n * stride_kn)
|
| 131 |
-
else:
|
| 132 |
-
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
|
| 133 |
-
elif EVEN_HEADDIM:
|
| 134 |
-
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
|
| 135 |
-
else:
|
| 136 |
-
k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 137 |
-
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 138 |
-
qk += tl.dot(q, k, trans_b=True)
|
| 139 |
-
if not EVEN_N:
|
| 140 |
-
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
| 141 |
-
if IS_CAUSAL:
|
| 142 |
-
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
| 143 |
-
if BIAS_TYPE != "none":
|
| 144 |
-
if BIAS_TYPE == "vector":
|
| 145 |
-
if EVEN_N:
|
| 146 |
-
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
| 147 |
-
else:
|
| 148 |
-
bias = tl.load(b_ptrs + start_n, mask=start_n + offs_n < seqlen_k, other=0.0).to(tl.float32)
|
| 149 |
-
bias = bias[None, :]
|
| 150 |
-
elif BIAS_TYPE == "matrix":
|
| 151 |
-
if EVEN_M & EVEN_N:
|
| 152 |
-
bias = tl.load(b_ptrs + start_n).to(tl.float32)
|
| 153 |
-
else:
|
| 154 |
-
bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
|
| 155 |
-
qk = qk * softmax_scale + bias
|
| 156 |
-
m_ij = tl.maximum(tl.max(qk, 1), lse_i)
|
| 157 |
-
p = tl.exp(qk - m_ij[:, None])
|
| 158 |
-
else:
|
| 159 |
-
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
| 160 |
-
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
| 161 |
-
l_ij = tl.sum(p, 1)
|
| 162 |
-
acc_o_scale = tl.exp(m_i - m_ij)
|
| 163 |
-
tl.store(t_ptrs, acc_o_scale)
|
| 164 |
-
acc_o_scale = tl.load(t_ptrs)
|
| 165 |
-
acc_o = acc_o * acc_o_scale[:, None]
|
| 166 |
-
if EVEN_N & EVEN_M:
|
| 167 |
-
if EVEN_HEADDIM:
|
| 168 |
-
v = tl.load(v_ptrs + start_n * stride_vn)
|
| 169 |
-
else:
|
| 170 |
-
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
|
| 171 |
-
elif EVEN_HEADDIM:
|
| 172 |
-
v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
|
| 173 |
-
else:
|
| 174 |
-
v = tl.load(v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 175 |
-
p = p.to(v.dtype)
|
| 176 |
-
acc_o += tl.dot(p, v)
|
| 177 |
-
m_i = m_ij
|
| 178 |
-
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
| 179 |
-
lse_i = m_ij + tl.log(l_i_new)
|
| 180 |
-
o_scale = tl.exp(m_i - lse_i)
|
| 181 |
-
tl.store(t_ptrs, o_scale)
|
| 182 |
-
o_scale = tl.load(t_ptrs)
|
| 183 |
-
acc_o = acc_o * o_scale[:, None]
|
| 184 |
-
start_m = tl.program_id(0)
|
| 185 |
-
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 186 |
-
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
| 187 |
-
tl.store(lse_ptrs, lse_i)
|
| 188 |
-
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 189 |
-
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
|
| 190 |
-
if EVEN_M:
|
| 191 |
-
if EVEN_HEADDIM:
|
| 192 |
-
tl.store(out_ptrs, acc_o)
|
| 193 |
-
else:
|
| 194 |
-
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
| 195 |
-
elif EVEN_HEADDIM:
|
| 196 |
-
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
| 197 |
-
else:
|
| 198 |
-
tl.store(out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
@triton.jit
|
| 202 |
-
def _bwd_preprocess_do_o_dot(
|
| 203 |
-
Out,
|
| 204 |
-
DO,
|
| 205 |
-
Delta,
|
| 206 |
-
stride_ob,
|
| 207 |
-
stride_oh,
|
| 208 |
-
stride_om,
|
| 209 |
-
stride_dob,
|
| 210 |
-
stride_doh,
|
| 211 |
-
stride_dom,
|
| 212 |
-
nheads,
|
| 213 |
-
seqlen_q,
|
| 214 |
-
seqlen_q_rounded,
|
| 215 |
-
headdim,
|
| 216 |
-
BLOCK_M: tl.constexpr,
|
| 217 |
-
BLOCK_HEADDIM: tl.constexpr,
|
| 218 |
-
):
|
| 219 |
-
start_m = tl.program_id(0)
|
| 220 |
-
off_hb = tl.program_id(1)
|
| 221 |
-
off_b = off_hb // nheads
|
| 222 |
-
off_h = off_hb % nheads
|
| 223 |
-
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 224 |
-
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 225 |
-
o = tl.load(
|
| 226 |
-
Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
|
| 227 |
-
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 228 |
-
other=0.0,
|
| 229 |
-
).to(tl.float32)
|
| 230 |
-
do = tl.load(
|
| 231 |
-
DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
|
| 232 |
-
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
|
| 233 |
-
other=0.0,
|
| 234 |
-
).to(tl.float32)
|
| 235 |
-
delta = tl.sum(o * do, axis=1)
|
| 236 |
-
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
@triton.jit
|
| 240 |
-
def _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
|
| 241 |
-
if EVEN_N & EVEN_M:
|
| 242 |
-
if EVEN_HEADDIM:
|
| 243 |
-
tl.store(dv_ptrs, dv)
|
| 244 |
-
tl.store(dk_ptrs, dk)
|
| 245 |
-
else:
|
| 246 |
-
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
|
| 247 |
-
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
|
| 248 |
-
elif EVEN_HEADDIM:
|
| 249 |
-
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
|
| 250 |
-
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
|
| 251 |
-
else:
|
| 252 |
-
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 253 |
-
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
@triton.jit
|
| 257 |
-
def _bwd_kernel_one_col_block(
|
| 258 |
-
start_n,
|
| 259 |
-
Q,
|
| 260 |
-
K,
|
| 261 |
-
V,
|
| 262 |
-
Bias,
|
| 263 |
-
DO,
|
| 264 |
-
DQ,
|
| 265 |
-
DK,
|
| 266 |
-
DV,
|
| 267 |
-
LSE,
|
| 268 |
-
D,
|
| 269 |
-
softmax_scale,
|
| 270 |
-
stride_qm,
|
| 271 |
-
stride_kn,
|
| 272 |
-
stride_vn,
|
| 273 |
-
stride_bm,
|
| 274 |
-
stride_dom,
|
| 275 |
-
stride_dqm,
|
| 276 |
-
stride_dkn,
|
| 277 |
-
stride_dvn,
|
| 278 |
-
seqlen_q,
|
| 279 |
-
seqlen_k,
|
| 280 |
-
headdim,
|
| 281 |
-
ATOMIC_ADD: tl.constexpr,
|
| 282 |
-
BIAS_TYPE: tl.constexpr,
|
| 283 |
-
IS_CAUSAL: tl.constexpr,
|
| 284 |
-
BLOCK_HEADDIM: tl.constexpr,
|
| 285 |
-
EVEN_M: tl.constexpr,
|
| 286 |
-
EVEN_N: tl.constexpr,
|
| 287 |
-
EVEN_HEADDIM: tl.constexpr,
|
| 288 |
-
BLOCK_M: tl.constexpr,
|
| 289 |
-
BLOCK_N: tl.constexpr,
|
| 290 |
-
):
|
| 291 |
-
begin_m = 0 if not IS_CAUSAL else start_n * BLOCK_N // BLOCK_M * BLOCK_M
|
| 292 |
-
offs_qm = begin_m + tl.arange(0, BLOCK_M)
|
| 293 |
-
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 294 |
-
offs_m = tl.arange(0, BLOCK_M)
|
| 295 |
-
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
| 296 |
-
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
|
| 297 |
-
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
| 298 |
-
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
| 299 |
-
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
|
| 300 |
-
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
|
| 301 |
-
if BIAS_TYPE == "vector":
|
| 302 |
-
b_ptrs = Bias + offs_n
|
| 303 |
-
elif BIAS_TYPE == "matrix":
|
| 304 |
-
b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
|
| 305 |
-
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 306 |
-
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
| 307 |
-
if begin_m >= seqlen_q:
|
| 308 |
-
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
| 309 |
-
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
| 310 |
-
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
|
| 311 |
-
return
|
| 312 |
-
if EVEN_N & EVEN_M:
|
| 313 |
-
if EVEN_HEADDIM:
|
| 314 |
-
k = tl.load(k_ptrs)
|
| 315 |
-
v = tl.load(v_ptrs)
|
| 316 |
-
else:
|
| 317 |
-
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 318 |
-
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
| 319 |
-
elif EVEN_HEADDIM:
|
| 320 |
-
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 321 |
-
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
|
| 322 |
-
else:
|
| 323 |
-
k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 324 |
-
v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
|
| 325 |
-
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
|
| 326 |
-
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
|
| 327 |
-
start_m = tl.multiple_of(start_m, BLOCK_M)
|
| 328 |
-
offs_m_curr = start_m + offs_m
|
| 329 |
-
if EVEN_M & EVEN_HEADDIM:
|
| 330 |
-
q = tl.load(q_ptrs)
|
| 331 |
-
elif EVEN_HEADDIM:
|
| 332 |
-
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
|
| 333 |
-
else:
|
| 334 |
-
q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 335 |
-
qk = tl.dot(q, k, trans_b=True)
|
| 336 |
-
if not EVEN_N:
|
| 337 |
-
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
|
| 338 |
-
if IS_CAUSAL:
|
| 339 |
-
qk = tl.where(offs_m_curr[:, None] >= offs_n[None, :], qk, float("-inf"))
|
| 340 |
-
if BIAS_TYPE != "none":
|
| 341 |
-
tl.debug_barrier()
|
| 342 |
-
if BIAS_TYPE == "vector":
|
| 343 |
-
if EVEN_N:
|
| 344 |
-
bias = tl.load(b_ptrs).to(tl.float32)
|
| 345 |
-
else:
|
| 346 |
-
bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
|
| 347 |
-
bias = bias[None, :]
|
| 348 |
-
elif BIAS_TYPE == "matrix":
|
| 349 |
-
if EVEN_M & EVEN_N:
|
| 350 |
-
bias = tl.load(b_ptrs).to(tl.float32)
|
| 351 |
-
else:
|
| 352 |
-
bias = tl.load(b_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0).to(tl.float32)
|
| 353 |
-
qk = qk * softmax_scale + bias
|
| 354 |
-
if not EVEN_M & EVEN_HEADDIM:
|
| 355 |
-
tl.debug_barrier()
|
| 356 |
-
lse_i = tl.load(LSE + offs_m_curr)
|
| 357 |
-
if BIAS_TYPE == "none":
|
| 358 |
-
p = tl.exp(qk * softmax_scale - lse_i[:, None])
|
| 359 |
-
else:
|
| 360 |
-
p = tl.exp(qk - lse_i[:, None])
|
| 361 |
-
if EVEN_M & EVEN_HEADDIM:
|
| 362 |
-
do = tl.load(do_ptrs)
|
| 363 |
-
else:
|
| 364 |
-
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
|
| 365 |
-
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
| 366 |
-
if not EVEN_M & EVEN_HEADDIM:
|
| 367 |
-
tl.debug_barrier()
|
| 368 |
-
dp = tl.dot(do, v, trans_b=True)
|
| 369 |
-
if not EVEN_HEADDIM:
|
| 370 |
-
tl.debug_barrier()
|
| 371 |
-
Di = tl.load(D + offs_m_curr)
|
| 372 |
-
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
| 373 |
-
dk += tl.dot(ds, q, trans_a=True)
|
| 374 |
-
if not EVEN_M & EVEN_HEADDIM:
|
| 375 |
-
tl.debug_barrier()
|
| 376 |
-
if not ATOMIC_ADD:
|
| 377 |
-
if EVEN_M & EVEN_HEADDIM:
|
| 378 |
-
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
| 379 |
-
dq += tl.dot(ds, k)
|
| 380 |
-
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
| 381 |
-
elif EVEN_HEADDIM:
|
| 382 |
-
dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, eviction_policy="evict_last")
|
| 383 |
-
dq += tl.dot(ds, k)
|
| 384 |
-
tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, eviction_policy="evict_last")
|
| 385 |
-
else:
|
| 386 |
-
dq = tl.load(dq_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, eviction_policy="evict_last")
|
| 387 |
-
dq += tl.dot(ds, k)
|
| 388 |
-
tl.store(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), eviction_policy="evict_last")
|
| 389 |
-
else:
|
| 390 |
-
dq = tl.dot(ds, k)
|
| 391 |
-
if EVEN_M & EVEN_HEADDIM:
|
| 392 |
-
tl.atomic_add(dq_ptrs, dq)
|
| 393 |
-
elif EVEN_HEADDIM:
|
| 394 |
-
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
|
| 395 |
-
else:
|
| 396 |
-
tl.atomic_add(dq_ptrs, dq, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
|
| 397 |
-
dq_ptrs += BLOCK_M * stride_dqm
|
| 398 |
-
q_ptrs += BLOCK_M * stride_qm
|
| 399 |
-
do_ptrs += BLOCK_M * stride_dom
|
| 400 |
-
if BIAS_TYPE == "matrix":
|
| 401 |
-
b_ptrs += BLOCK_M * stride_bm
|
| 402 |
-
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
|
| 403 |
-
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
|
| 404 |
-
_bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
def init_to_zero(name):
|
| 408 |
-
return lambda nargs: nargs[name].zero_()
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
@triton.autotune(
|
| 412 |
-
configs=[
|
| 413 |
-
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
|
| 414 |
-
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
|
| 415 |
-
],
|
| 416 |
-
key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
|
| 417 |
-
)
|
| 418 |
-
@triton.heuristics(
|
| 419 |
-
{
|
| 420 |
-
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
| 421 |
-
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
| 422 |
-
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
| 423 |
-
}
|
| 424 |
-
)
|
| 425 |
-
@triton.jit
|
| 426 |
-
def _bwd_kernel(
|
| 427 |
-
Q,
|
| 428 |
-
K,
|
| 429 |
-
V,
|
| 430 |
-
Bias,
|
| 431 |
-
DO,
|
| 432 |
-
DQ,
|
| 433 |
-
DK,
|
| 434 |
-
DV,
|
| 435 |
-
LSE,
|
| 436 |
-
D,
|
| 437 |
-
softmax_scale,
|
| 438 |
-
stride_qb,
|
| 439 |
-
stride_qh,
|
| 440 |
-
stride_qm,
|
| 441 |
-
stride_kb,
|
| 442 |
-
stride_kh,
|
| 443 |
-
stride_kn,
|
| 444 |
-
stride_vb,
|
| 445 |
-
stride_vh,
|
| 446 |
-
stride_vn,
|
| 447 |
-
stride_bb,
|
| 448 |
-
stride_bh,
|
| 449 |
-
stride_bm,
|
| 450 |
-
stride_dob,
|
| 451 |
-
stride_doh,
|
| 452 |
-
stride_dom,
|
| 453 |
-
stride_dqb,
|
| 454 |
-
stride_dqh,
|
| 455 |
-
stride_dqm,
|
| 456 |
-
stride_dkb,
|
| 457 |
-
stride_dkh,
|
| 458 |
-
stride_dkn,
|
| 459 |
-
stride_dvb,
|
| 460 |
-
stride_dvh,
|
| 461 |
-
stride_dvn,
|
| 462 |
-
nheads,
|
| 463 |
-
seqlen_q,
|
| 464 |
-
seqlen_k,
|
| 465 |
-
seqlen_q_rounded,
|
| 466 |
-
headdim,
|
| 467 |
-
CACHE_KEY_SEQLEN_Q,
|
| 468 |
-
CACHE_KEY_SEQLEN_K,
|
| 469 |
-
BIAS_TYPE: tl.constexpr,
|
| 470 |
-
IS_CAUSAL: tl.constexpr,
|
| 471 |
-
BLOCK_HEADDIM: tl.constexpr,
|
| 472 |
-
SEQUENCE_PARALLEL: tl.constexpr,
|
| 473 |
-
EVEN_M: tl.constexpr,
|
| 474 |
-
EVEN_N: tl.constexpr,
|
| 475 |
-
EVEN_HEADDIM: tl.constexpr,
|
| 476 |
-
BLOCK_M: tl.constexpr,
|
| 477 |
-
BLOCK_N: tl.constexpr,
|
| 478 |
-
):
|
| 479 |
-
off_hb = tl.program_id(1)
|
| 480 |
-
off_b = off_hb // nheads
|
| 481 |
-
off_h = off_hb % nheads
|
| 482 |
-
Q += off_b * stride_qb + off_h * stride_qh
|
| 483 |
-
K += off_b * stride_kb + off_h * stride_kh
|
| 484 |
-
V += off_b * stride_vb + off_h * stride_vh
|
| 485 |
-
DO += off_b * stride_dob + off_h * stride_doh
|
| 486 |
-
DQ += off_b * stride_dqb + off_h * stride_dqh
|
| 487 |
-
DK += off_b * stride_dkb + off_h * stride_dkh
|
| 488 |
-
DV += off_b * stride_dvb + off_h * stride_dvh
|
| 489 |
-
if BIAS_TYPE != "none":
|
| 490 |
-
Bias += off_b * stride_bb + off_h * stride_bh
|
| 491 |
-
D += off_hb * seqlen_q_rounded
|
| 492 |
-
LSE += off_hb * seqlen_q_rounded
|
| 493 |
-
if not SEQUENCE_PARALLEL:
|
| 494 |
-
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
|
| 495 |
-
for start_n in range(0, num_block_n):
|
| 496 |
-
_bwd_kernel_one_col_block(
|
| 497 |
-
start_n,
|
| 498 |
-
Q,
|
| 499 |
-
K,
|
| 500 |
-
V,
|
| 501 |
-
Bias,
|
| 502 |
-
DO,
|
| 503 |
-
DQ,
|
| 504 |
-
DK,
|
| 505 |
-
DV,
|
| 506 |
-
LSE,
|
| 507 |
-
D,
|
| 508 |
-
softmax_scale,
|
| 509 |
-
stride_qm,
|
| 510 |
-
stride_kn,
|
| 511 |
-
stride_vn,
|
| 512 |
-
stride_bm,
|
| 513 |
-
stride_dom,
|
| 514 |
-
stride_dqm,
|
| 515 |
-
stride_dkn,
|
| 516 |
-
stride_dvn,
|
| 517 |
-
seqlen_q,
|
| 518 |
-
seqlen_k,
|
| 519 |
-
headdim,
|
| 520 |
-
ATOMIC_ADD=False,
|
| 521 |
-
BIAS_TYPE=BIAS_TYPE,
|
| 522 |
-
IS_CAUSAL=IS_CAUSAL,
|
| 523 |
-
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
| 524 |
-
EVEN_M=EVEN_M,
|
| 525 |
-
EVEN_N=EVEN_N,
|
| 526 |
-
EVEN_HEADDIM=EVEN_HEADDIM,
|
| 527 |
-
BLOCK_M=BLOCK_M,
|
| 528 |
-
BLOCK_N=BLOCK_N,
|
| 529 |
-
)
|
| 530 |
-
else:
|
| 531 |
-
start_n = tl.program_id(0)
|
| 532 |
-
_bwd_kernel_one_col_block(
|
| 533 |
-
start_n,
|
| 534 |
-
Q,
|
| 535 |
-
K,
|
| 536 |
-
V,
|
| 537 |
-
Bias,
|
| 538 |
-
DO,
|
| 539 |
-
DQ,
|
| 540 |
-
DK,
|
| 541 |
-
DV,
|
| 542 |
-
LSE,
|
| 543 |
-
D,
|
| 544 |
-
softmax_scale,
|
| 545 |
-
stride_qm,
|
| 546 |
-
stride_kn,
|
| 547 |
-
stride_vn,
|
| 548 |
-
stride_bm,
|
| 549 |
-
stride_dom,
|
| 550 |
-
stride_dqm,
|
| 551 |
-
stride_dkn,
|
| 552 |
-
stride_dvn,
|
| 553 |
-
seqlen_q,
|
| 554 |
-
seqlen_k,
|
| 555 |
-
headdim,
|
| 556 |
-
ATOMIC_ADD=True,
|
| 557 |
-
BIAS_TYPE=BIAS_TYPE,
|
| 558 |
-
IS_CAUSAL=IS_CAUSAL,
|
| 559 |
-
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
| 560 |
-
EVEN_M=EVEN_M,
|
| 561 |
-
EVEN_N=EVEN_N,
|
| 562 |
-
EVEN_HEADDIM=EVEN_HEADDIM,
|
| 563 |
-
BLOCK_M=BLOCK_M,
|
| 564 |
-
BLOCK_N=BLOCK_N,
|
| 565 |
-
)
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
|
| 569 |
-
(batch, seqlen_q, nheads, d) = q.shape
|
| 570 |
-
(_, seqlen_k, _, _) = k.shape
|
| 571 |
-
assert k.shape == (batch, seqlen_k, nheads, d)
|
| 572 |
-
assert v.shape == (batch, seqlen_k, nheads, d)
|
| 573 |
-
assert d <= 128, "FlashAttention only support head dimensions up to 128"
|
| 574 |
-
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
|
| 575 |
-
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
| 576 |
-
assert q.is_cuda and k.is_cuda and v.is_cuda
|
| 577 |
-
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
| 578 |
-
has_bias = bias is not None
|
| 579 |
-
bias_type = "none"
|
| 580 |
-
if has_bias:
|
| 581 |
-
assert bias.dtype in [q.dtype, torch.float]
|
| 582 |
-
assert bias.is_cuda
|
| 583 |
-
assert bias.dim() == 4
|
| 584 |
-
if bias.stride(-1) != 1:
|
| 585 |
-
bias = bias.contiguous()
|
| 586 |
-
if bias.shape[2:] == (1, seqlen_k):
|
| 587 |
-
bias_type = "vector"
|
| 588 |
-
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
| 589 |
-
bias_type = "matrix"
|
| 590 |
-
else:
|
| 591 |
-
raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
|
| 592 |
-
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
|
| 593 |
-
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
|
| 594 |
-
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
| 595 |
-
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
| 596 |
-
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
| 597 |
-
o = torch.empty_like(q)
|
| 598 |
-
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
| 599 |
-
BLOCK = 128
|
| 600 |
-
num_warps = 4 if d <= 64 else 8
|
| 601 |
-
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 602 |
-
_fwd_kernel[grid](
|
| 603 |
-
q,
|
| 604 |
-
k,
|
| 605 |
-
v,
|
| 606 |
-
bias,
|
| 607 |
-
o,
|
| 608 |
-
lse,
|
| 609 |
-
tmp,
|
| 610 |
-
softmax_scale,
|
| 611 |
-
q.stride(0),
|
| 612 |
-
q.stride(2),
|
| 613 |
-
q.stride(1),
|
| 614 |
-
k.stride(0),
|
| 615 |
-
k.stride(2),
|
| 616 |
-
k.stride(1),
|
| 617 |
-
v.stride(0),
|
| 618 |
-
v.stride(2),
|
| 619 |
-
v.stride(1),
|
| 620 |
-
*bias_strides,
|
| 621 |
-
o.stride(0),
|
| 622 |
-
o.stride(2),
|
| 623 |
-
o.stride(1),
|
| 624 |
-
nheads,
|
| 625 |
-
seqlen_q,
|
| 626 |
-
seqlen_k,
|
| 627 |
-
seqlen_q_rounded,
|
| 628 |
-
d,
|
| 629 |
-
seqlen_q // 32,
|
| 630 |
-
seqlen_k // 32,
|
| 631 |
-
bias_type,
|
| 632 |
-
causal,
|
| 633 |
-
BLOCK_HEADDIM,
|
| 634 |
-
BLOCK_M=BLOCK,
|
| 635 |
-
BLOCK_N=BLOCK,
|
| 636 |
-
num_warps=num_warps,
|
| 637 |
-
num_stages=1
|
| 638 |
-
)
|
| 639 |
-
return (o, lse, softmax_scale)
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None):
|
| 643 |
-
if do.stride(-1) != 1:
|
| 644 |
-
do = do.contiguous()
|
| 645 |
-
(batch, seqlen_q, nheads, d) = q.shape
|
| 646 |
-
(_, seqlen_k, _, _) = k.shape
|
| 647 |
-
assert d <= 128
|
| 648 |
-
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
| 649 |
-
assert lse.shape == (batch, nheads, seqlen_q_rounded)
|
| 650 |
-
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
|
| 651 |
-
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
|
| 652 |
-
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
| 653 |
-
dq_accum = torch.empty_like(q, dtype=torch.float32)
|
| 654 |
-
delta = torch.empty_like(lse)
|
| 655 |
-
BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
|
| 656 |
-
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
| 657 |
-
_bwd_preprocess_do_o_dot[grid](
|
| 658 |
-
o,
|
| 659 |
-
do,
|
| 660 |
-
delta,
|
| 661 |
-
o.stride(0),
|
| 662 |
-
o.stride(2),
|
| 663 |
-
o.stride(1),
|
| 664 |
-
do.stride(0),
|
| 665 |
-
do.stride(2),
|
| 666 |
-
do.stride(1),
|
| 667 |
-
nheads,
|
| 668 |
-
seqlen_q,
|
| 669 |
-
seqlen_q_rounded,
|
| 670 |
-
d,
|
| 671 |
-
BLOCK_M=128,
|
| 672 |
-
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
| 673 |
-
)
|
| 674 |
-
has_bias = bias is not None
|
| 675 |
-
bias_type = "none"
|
| 676 |
-
if has_bias:
|
| 677 |
-
assert bias.dtype in [q.dtype, torch.float]
|
| 678 |
-
assert bias.is_cuda
|
| 679 |
-
assert bias.dim() == 4
|
| 680 |
-
assert bias.stride(-1) == 1
|
| 681 |
-
if bias.shape[2:] == (1, seqlen_k):
|
| 682 |
-
bias_type = "vector"
|
| 683 |
-
elif bias.shape[2:] == (seqlen_q, seqlen_k):
|
| 684 |
-
bias_type = "matrix"
|
| 685 |
-
else:
|
| 686 |
-
raise RuntimeError("Last 2 dimensions of bias must be (1, seqlen_k) or (seqlen_q, seqlen_k)")
|
| 687 |
-
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
|
| 688 |
-
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
|
| 689 |
-
grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, batch * nheads)
|
| 690 |
-
_bwd_kernel[grid](
|
| 691 |
-
q,
|
| 692 |
-
k,
|
| 693 |
-
v,
|
| 694 |
-
bias,
|
| 695 |
-
do,
|
| 696 |
-
dq_accum,
|
| 697 |
-
dk,
|
| 698 |
-
dv,
|
| 699 |
-
lse,
|
| 700 |
-
delta,
|
| 701 |
-
softmax_scale,
|
| 702 |
-
q.stride(0),
|
| 703 |
-
q.stride(2),
|
| 704 |
-
q.stride(1),
|
| 705 |
-
k.stride(0),
|
| 706 |
-
k.stride(2),
|
| 707 |
-
k.stride(1),
|
| 708 |
-
v.stride(0),
|
| 709 |
-
v.stride(2),
|
| 710 |
-
v.stride(1),
|
| 711 |
-
*bias_strides,
|
| 712 |
-
do.stride(0),
|
| 713 |
-
do.stride(2),
|
| 714 |
-
do.stride(1),
|
| 715 |
-
dq_accum.stride(0),
|
| 716 |
-
dq_accum.stride(2),
|
| 717 |
-
dq_accum.stride(1),
|
| 718 |
-
dk.stride(0),
|
| 719 |
-
dk.stride(2),
|
| 720 |
-
dk.stride(1),
|
| 721 |
-
dv.stride(0),
|
| 722 |
-
dv.stride(2),
|
| 723 |
-
dv.stride(1),
|
| 724 |
-
nheads,
|
| 725 |
-
seqlen_q,
|
| 726 |
-
seqlen_k,
|
| 727 |
-
seqlen_q_rounded,
|
| 728 |
-
d,
|
| 729 |
-
seqlen_q // 32,
|
| 730 |
-
seqlen_k // 32,
|
| 731 |
-
bias_type,
|
| 732 |
-
causal,
|
| 733 |
-
BLOCK_HEADDIM
|
| 734 |
-
)
|
| 735 |
-
dq.copy_(dq_accum)
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
| 739 |
-
@staticmethod
|
| 740 |
-
def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
|
| 741 |
-
"""
|
| 742 |
-
qkv: (batch, seqlen, 3, nheads, headdim)
|
| 743 |
-
bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
|
| 744 |
-
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
|
| 745 |
-
ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
|
| 746 |
-
"""
|
| 747 |
-
if qkv.stride(-1) != 1:
|
| 748 |
-
qkv = qkv.contiguous()
|
| 749 |
-
(o, lse, ctx.softmax_scale) = _flash_attn_forward(qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, softmax_scale=softmax_scale)
|
| 750 |
-
ctx.save_for_backward(qkv, o, lse, bias)
|
| 751 |
-
ctx.causal = causal
|
| 752 |
-
return o
|
| 753 |
-
|
| 754 |
-
@staticmethod
|
| 755 |
-
def backward(ctx, do):
|
| 756 |
-
(qkv, o, lse, bias) = ctx.saved_tensors
|
| 757 |
-
assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
|
| 758 |
-
with torch.inference_mode():
|
| 759 |
-
dqkv = torch.empty_like(qkv)
|
| 760 |
-
_flash_attn_backward(
|
| 761 |
-
do,
|
| 762 |
-
qkv[:, :, 0],
|
| 763 |
-
qkv[:, :, 1],
|
| 764 |
-
qkv[:, :, 2],
|
| 765 |
-
o,
|
| 766 |
-
lse,
|
| 767 |
-
dqkv[:, :, 0],
|
| 768 |
-
dqkv[:, :, 1],
|
| 769 |
-
dqkv[:, :, 2],
|
| 770 |
-
bias=bias,
|
| 771 |
-
causal=ctx.causal,
|
| 772 |
-
softmax_scale=ctx.softmax_scale,
|
| 773 |
-
)
|
| 774 |
-
return (dqkv, None, None, None)
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
| 781 |
-
@staticmethod
|
| 782 |
-
def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
|
| 783 |
-
"""
|
| 784 |
-
q: (batch, seqlen_q, nheads, headdim)
|
| 785 |
-
kv: (batch, seqlen_k, 2, nheads, headdim)
|
| 786 |
-
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
|
| 787 |
-
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
|
| 788 |
-
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
|
| 789 |
-
"""
|
| 790 |
-
(q, kv) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
|
| 791 |
-
(o, lse, ctx.softmax_scale) = _flash_attn_forward(q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale)
|
| 792 |
-
ctx.save_for_backward(q, kv, o, lse, bias)
|
| 793 |
-
ctx.causal = causal
|
| 794 |
-
return o
|
| 795 |
-
|
| 796 |
-
@staticmethod
|
| 797 |
-
def backward(ctx, do):
|
| 798 |
-
(q, kv, o, lse, bias) = ctx.saved_tensors
|
| 799 |
-
if len(ctx.needs_input_grad) >= 3:
|
| 800 |
-
assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
|
| 801 |
-
with torch.inference_mode():
|
| 802 |
-
dq = torch.empty_like(q)
|
| 803 |
-
dkv = torch.empty_like(kv)
|
| 804 |
-
_flash_attn_backward(
|
| 805 |
-
do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
|
| 806 |
-
)
|
| 807 |
-
return (dq, dkv, None, None, None)
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
class FlashAttnFunc(torch.autograd.Function):
|
| 814 |
-
@staticmethod
|
| 815 |
-
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
|
| 816 |
-
"""
|
| 817 |
-
q: (batch_size, seqlen_q, nheads, headdim)
|
| 818 |
-
k, v: (batch_size, seqlen_k, nheads, headdim)
|
| 819 |
-
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
|
| 820 |
-
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
|
| 821 |
-
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
|
| 822 |
-
"""
|
| 823 |
-
(q, k, v) = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
|
| 824 |
-
(o, lse, ctx.softmax_scale) = _flash_attn_forward(q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
|
| 825 |
-
ctx.save_for_backward(q, k, v, o, lse, bias)
|
| 826 |
-
ctx.causal = causal
|
| 827 |
-
return o
|
| 828 |
-
|
| 829 |
-
@staticmethod
|
| 830 |
-
def backward(ctx, do):
|
| 831 |
-
(q, k, v, o, lse, bias) = ctx.saved_tensors
|
| 832 |
-
assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
|
| 833 |
-
with torch.inference_mode():
|
| 834 |
-
dq = torch.empty_like(q)
|
| 835 |
-
dk = torch.empty_like(k)
|
| 836 |
-
dv = torch.empty_like(v)
|
| 837 |
-
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
|
| 838 |
-
return (dq, dk, dv, None, None, None)
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
flash_attn_func = FlashAttnFunc.apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/hf_prefixlm_converter.py
DELETED
|
@@ -1,575 +0,0 @@
|
|
| 1 |
-
"""Converts Huggingface Causal LM to Prefix LM.
|
| 2 |
-
|
| 3 |
-
Conversion does lightweight surgery on a HuggingFace
|
| 4 |
-
Causal LM to convert it to a Prefix LM.
|
| 5 |
-
|
| 6 |
-
Prefix LMs accepts a `bidirectional_mask` input in `forward`
|
| 7 |
-
and treat the input prompt as the prefix in `generate`.
|
| 8 |
-
"""
|
| 9 |
-
import math
|
| 10 |
-
import warnings
|
| 11 |
-
from types import MethodType
|
| 12 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 13 |
-
import torch
|
| 14 |
-
from transformers.models.bloom.modeling_bloom import (
|
| 15 |
-
BaseModelOutputWithPastAndCrossAttentions,
|
| 16 |
-
BloomForCausalLM,
|
| 17 |
-
BloomModel,
|
| 18 |
-
CausalLMOutputWithCrossAttentions,
|
| 19 |
-
CrossEntropyLoss,
|
| 20 |
-
)
|
| 21 |
-
from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
|
| 22 |
-
from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
|
| 23 |
-
from transformers.models.bloom.modeling_bloom import logging
|
| 24 |
-
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
| 25 |
-
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
|
| 26 |
-
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
|
| 27 |
-
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
|
| 28 |
-
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
| 29 |
-
from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
|
| 30 |
-
from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
|
| 31 |
-
|
| 32 |
-
logger = logging.get_logger(__name__)
|
| 33 |
-
_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
|
| 34 |
-
CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
|
| 38 |
-
"""Converts a GPT-style Causal LM to a Prefix LM.
|
| 39 |
-
|
| 40 |
-
Supported HuggingFace model classes:
|
| 41 |
-
- `GPT2LMHeadModel`
|
| 42 |
-
- `GPTNeoForCausalLM`
|
| 43 |
-
- `GPTNeoXForCausalLM`
|
| 44 |
-
- `GPTJForCausalLM`
|
| 45 |
-
|
| 46 |
-
See `convert_hf_causal_lm_to_prefix_lm` for more details.
|
| 47 |
-
"""
|
| 48 |
-
if hasattr(model, "_prefix_lm_converted"):
|
| 49 |
-
return model
|
| 50 |
-
assert isinstance(model, _SUPPORTED_GPT_MODELS)
|
| 51 |
-
assert model.config.add_cross_attention == False, "Only supports GPT-style decoder-only models"
|
| 52 |
-
|
| 53 |
-
def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
|
| 54 |
-
"""Helper that gets a list of the model's attention modules.
|
| 55 |
-
|
| 56 |
-
Each module has a `bias` buffer used for causal masking. The Prefix LM
|
| 57 |
-
conversion adds logic to dynamically manipulate these biases to support
|
| 58 |
-
Prefix LM attention masking.
|
| 59 |
-
"""
|
| 60 |
-
attn_modules = []
|
| 61 |
-
if isinstance(model, GPTNeoXForCausalLM):
|
| 62 |
-
blocks = model.gpt_neox.layers
|
| 63 |
-
else:
|
| 64 |
-
blocks = model.transformer.h
|
| 65 |
-
for block in blocks:
|
| 66 |
-
if isinstance(model, GPTNeoForCausalLM):
|
| 67 |
-
if block.attn.attention_type != "global":
|
| 68 |
-
continue
|
| 69 |
-
attn_module = block.attn.attention
|
| 70 |
-
elif isinstance(model, GPTNeoXForCausalLM):
|
| 71 |
-
attn_module = block.attention
|
| 72 |
-
else:
|
| 73 |
-
attn_module = block.attn
|
| 74 |
-
attn_modules.append(attn_module)
|
| 75 |
-
return attn_modules
|
| 76 |
-
|
| 77 |
-
setattr(model, "_original_forward", getattr(model, "forward"))
|
| 78 |
-
setattr(model, "_original_generate", getattr(model, "generate"))
|
| 79 |
-
|
| 80 |
-
def forward(
|
| 81 |
-
self: CAUSAL_GPT_TYPES,
|
| 82 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 83 |
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 84 |
-
attention_mask: Optional[torch.FloatTensor] = None,
|
| 85 |
-
bidirectional_mask: Optional[torch.Tensor] = None,
|
| 86 |
-
token_type_ids: Optional[torch.LongTensor] = None,
|
| 87 |
-
position_ids: Optional[torch.LongTensor] = None,
|
| 88 |
-
head_mask: Optional[torch.FloatTensor] = None,
|
| 89 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 90 |
-
labels: Optional[torch.LongTensor] = None,
|
| 91 |
-
use_cache: Optional[bool] = None,
|
| 92 |
-
output_attentions: Optional[bool] = None,
|
| 93 |
-
output_hidden_states: Optional[bool] = None,
|
| 94 |
-
return_dict: Optional[bool] = None,
|
| 95 |
-
):
|
| 96 |
-
"""Wraps original forward to enable PrefixLM attention."""
|
| 97 |
-
|
| 98 |
-
def call_og_forward():
|
| 99 |
-
if isinstance(self, GPTNeoXForCausalLM):
|
| 100 |
-
return self._original_forward(
|
| 101 |
-
input_ids=input_ids,
|
| 102 |
-
past_key_values=past_key_values,
|
| 103 |
-
attention_mask=attention_mask,
|
| 104 |
-
head_mask=head_mask,
|
| 105 |
-
inputs_embeds=inputs_embeds,
|
| 106 |
-
labels=labels,
|
| 107 |
-
use_cache=use_cache,
|
| 108 |
-
output_attentions=output_attentions,
|
| 109 |
-
output_hidden_states=output_hidden_states,
|
| 110 |
-
return_dict=return_dict,
|
| 111 |
-
)
|
| 112 |
-
else:
|
| 113 |
-
return self._original_forward(
|
| 114 |
-
input_ids=input_ids,
|
| 115 |
-
past_key_values=past_key_values,
|
| 116 |
-
attention_mask=attention_mask,
|
| 117 |
-
token_type_ids=token_type_ids,
|
| 118 |
-
position_ids=position_ids,
|
| 119 |
-
head_mask=head_mask,
|
| 120 |
-
inputs_embeds=inputs_embeds,
|
| 121 |
-
labels=labels,
|
| 122 |
-
use_cache=use_cache,
|
| 123 |
-
output_attentions=output_attentions,
|
| 124 |
-
output_hidden_states=output_hidden_states,
|
| 125 |
-
return_dict=return_dict,
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
if bidirectional_mask is None:
|
| 129 |
-
return call_og_forward()
|
| 130 |
-
assert isinstance(bidirectional_mask, torch.Tensor)
|
| 131 |
-
attn_modules = _get_attn_modules(model)
|
| 132 |
-
(b, s) = bidirectional_mask.shape
|
| 133 |
-
max_length = attn_modules[0].bias.shape[-1]
|
| 134 |
-
if s > max_length:
|
| 135 |
-
raise ValueError(f"bidirectional_mask sequence length (={s}) exceeds the " + f"max length allowed by the model ({max_length}).")
|
| 136 |
-
assert s <= max_length
|
| 137 |
-
if s < max_length:
|
| 138 |
-
pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
|
| 139 |
-
bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
|
| 140 |
-
bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
|
| 141 |
-
for attn_module in attn_modules:
|
| 142 |
-
attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
|
| 143 |
-
output = call_og_forward()
|
| 144 |
-
for attn_module in attn_modules:
|
| 145 |
-
attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
|
| 146 |
-
return output
|
| 147 |
-
|
| 148 |
-
def generate(self: CAUSAL_GPT_TYPES, *args: tuple, **kwargs: Dict[str, Any]):
|
| 149 |
-
"""Wraps original generate to enable PrefixLM attention."""
|
| 150 |
-
attn_modules = _get_attn_modules(model)
|
| 151 |
-
for attn_module in attn_modules:
|
| 152 |
-
attn_module.bias.data[:] = 1
|
| 153 |
-
output = self._original_generate(*args, **kwargs)
|
| 154 |
-
for attn_module in attn_modules:
|
| 155 |
-
attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
|
| 156 |
-
return output
|
| 157 |
-
|
| 158 |
-
setattr(model, "forward", MethodType(forward, model))
|
| 159 |
-
setattr(model, "generate", MethodType(generate, model))
|
| 160 |
-
setattr(model, "_prefix_lm_converted", True)
|
| 161 |
-
return model
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
|
| 165 |
-
"""Converts a BLOOM Causal LM to a Prefix LM.
|
| 166 |
-
|
| 167 |
-
Supported HuggingFace model classes:
|
| 168 |
-
- `BloomForCausalLM`
|
| 169 |
-
|
| 170 |
-
See `convert_hf_causal_lm_to_prefix_lm` for more details.
|
| 171 |
-
"""
|
| 172 |
-
if hasattr(model, "_prefix_lm_converted"):
|
| 173 |
-
return model
|
| 174 |
-
assert isinstance(model, BloomForCausalLM)
|
| 175 |
-
assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"
|
| 176 |
-
|
| 177 |
-
def _prepare_attn_mask(
|
| 178 |
-
self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int
|
| 179 |
-
) -> torch.BoolTensor:
|
| 180 |
-
combined_attention_mask = None
|
| 181 |
-
device = attention_mask.device
|
| 182 |
-
(_, src_length) = input_shape
|
| 183 |
-
if src_length > 1:
|
| 184 |
-
combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
|
| 185 |
-
if bidirectional_mask is not None:
|
| 186 |
-
assert attention_mask.shape == bidirectional_mask.shape
|
| 187 |
-
expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
|
| 188 |
-
combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
|
| 189 |
-
expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
|
| 190 |
-
combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
| 191 |
-
return combined_attention_mask
|
| 192 |
-
|
| 193 |
-
def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
| 194 |
-
num_heads = self.config.n_head
|
| 195 |
-
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
| 196 |
-
base = torch.tensor(2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
|
| 197 |
-
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
|
| 198 |
-
slopes = torch.pow(base, powers)
|
| 199 |
-
if closest_power_of_2 != num_heads:
|
| 200 |
-
extra_base = torch.tensor(2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))), device=device, dtype=torch.float32)
|
| 201 |
-
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
| 202 |
-
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
|
| 203 |
-
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
| 204 |
-
qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
|
| 205 |
-
ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
|
| 206 |
-
diffs = qa - ka + key_length - query_length
|
| 207 |
-
diffs = -diffs.abs()
|
| 208 |
-
alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
|
| 209 |
-
alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
|
| 210 |
-
return alibi.to(dtype)
|
| 211 |
-
|
| 212 |
-
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
|
| 213 |
-
|
| 214 |
-
def forward(
|
| 215 |
-
self: BloomModel,
|
| 216 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 217 |
-
past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
|
| 218 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 219 |
-
bidirectional_mask: Optional[torch.Tensor] = None,
|
| 220 |
-
head_mask: Optional[torch.LongTensor] = None,
|
| 221 |
-
inputs_embeds: Optional[torch.LongTensor] = None,
|
| 222 |
-
use_cache: Optional[bool] = None,
|
| 223 |
-
output_attentions: Optional[bool] = None,
|
| 224 |
-
output_hidden_states: Optional[bool] = None,
|
| 225 |
-
return_dict: Optional[bool] = None,
|
| 226 |
-
**deprecated_arguments,
|
| 227 |
-
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
| 228 |
-
if deprecated_arguments.pop("position_ids", False) is not False:
|
| 229 |
-
warnings.warn(
|
| 230 |
-
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning
|
| 231 |
-
)
|
| 232 |
-
if len(deprecated_arguments) > 0:
|
| 233 |
-
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
| 234 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 235 |
-
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 236 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 237 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 238 |
-
if input_ids is not None and inputs_embeds is not None:
|
| 239 |
-
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 240 |
-
elif input_ids is not None:
|
| 241 |
-
(batch_size, seq_length) = input_ids.shape
|
| 242 |
-
elif inputs_embeds is not None:
|
| 243 |
-
(batch_size, seq_length, _) = inputs_embeds.shape
|
| 244 |
-
else:
|
| 245 |
-
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 246 |
-
if past_key_values is None:
|
| 247 |
-
past_key_values = tuple([None] * len(self.h))
|
| 248 |
-
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 249 |
-
if inputs_embeds is None:
|
| 250 |
-
inputs_embeds = self.word_embeddings(input_ids)
|
| 251 |
-
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
| 252 |
-
presents = () if use_cache else None
|
| 253 |
-
all_self_attentions = () if output_attentions else None
|
| 254 |
-
all_hidden_states = () if output_hidden_states else None
|
| 255 |
-
seq_length_with_past = seq_length
|
| 256 |
-
past_key_values_length = 0
|
| 257 |
-
if past_key_values[0] is not None:
|
| 258 |
-
tmp = past_key_values[0][0]
|
| 259 |
-
past_key_values_length = tmp.shape[2]
|
| 260 |
-
seq_length_with_past = seq_length_with_past + past_key_values_length
|
| 261 |
-
if attention_mask is None:
|
| 262 |
-
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
| 263 |
-
else:
|
| 264 |
-
attention_mask = attention_mask.to(hidden_states.device)
|
| 265 |
-
alibi = self._build_alibi_tensor(
|
| 266 |
-
batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device
|
| 267 |
-
)
|
| 268 |
-
causal_mask = self._prepare_attn_mask(
|
| 269 |
-
attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length
|
| 270 |
-
)
|
| 271 |
-
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 272 |
-
if output_hidden_states:
|
| 273 |
-
hst = (hidden_states,)
|
| 274 |
-
all_hidden_states = all_hidden_states + hst
|
| 275 |
-
if self.gradient_checkpointing and self.training:
|
| 276 |
-
if use_cache:
|
| 277 |
-
logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
| 278 |
-
use_cache = False
|
| 279 |
-
|
| 280 |
-
def create_custom_forward(module):
|
| 281 |
-
def custom_forward(*inputs):
|
| 282 |
-
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
|
| 283 |
-
|
| 284 |
-
return custom_forward
|
| 285 |
-
|
| 286 |
-
outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
|
| 287 |
-
else:
|
| 288 |
-
outputs = block(
|
| 289 |
-
hidden_states,
|
| 290 |
-
layer_past=layer_past,
|
| 291 |
-
attention_mask=causal_mask,
|
| 292 |
-
head_mask=head_mask[i],
|
| 293 |
-
use_cache=use_cache,
|
| 294 |
-
output_attentions=output_attentions,
|
| 295 |
-
alibi=alibi,
|
| 296 |
-
)
|
| 297 |
-
hidden_states = outputs[0]
|
| 298 |
-
if use_cache is True:
|
| 299 |
-
presents = presents + (outputs[1],)
|
| 300 |
-
if output_attentions:
|
| 301 |
-
oa = (outputs[2 if use_cache else 1],)
|
| 302 |
-
all_self_attentions = all_self_attentions + oa
|
| 303 |
-
hidden_states = self.ln_f(hidden_states)
|
| 304 |
-
if output_hidden_states:
|
| 305 |
-
hst = (hidden_states,)
|
| 306 |
-
all_hidden_states = all_hidden_states + hst
|
| 307 |
-
if not return_dict:
|
| 308 |
-
return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
|
| 309 |
-
return BaseModelOutputWithPastAndCrossAttentions(
|
| 310 |
-
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
|
| 314 |
-
setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
|
| 315 |
-
setattr(model.transformer, "forward", MethodType(forward, model.transformer))
|
| 316 |
-
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
|
| 317 |
-
|
| 318 |
-
def forward(
|
| 319 |
-
self: BloomForCausalLM,
|
| 320 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 321 |
-
past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
|
| 322 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 323 |
-
bidirectional_mask: Optional[torch.Tensor] = None,
|
| 324 |
-
head_mask: Optional[torch.Tensor] = None,
|
| 325 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
| 326 |
-
labels: Optional[torch.Tensor] = None,
|
| 327 |
-
use_cache: Optional[bool] = None,
|
| 328 |
-
output_attentions: Optional[bool] = None,
|
| 329 |
-
output_hidden_states: Optional[bool] = None,
|
| 330 |
-
return_dict: Optional[bool] = None,
|
| 331 |
-
**deprecated_arguments,
|
| 332 |
-
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
| 333 |
-
"""Replacement forward method for BloomCausalLM."""
|
| 334 |
-
if deprecated_arguments.pop("position_ids", False) is not False:
|
| 335 |
-
warnings.warn(
|
| 336 |
-
"`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning
|
| 337 |
-
)
|
| 338 |
-
if len(deprecated_arguments) > 0:
|
| 339 |
-
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
| 340 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 341 |
-
transformer_outputs = self.transformer(
|
| 342 |
-
input_ids,
|
| 343 |
-
past_key_values=past_key_values,
|
| 344 |
-
attention_mask=attention_mask,
|
| 345 |
-
bidirectional_mask=bidirectional_mask,
|
| 346 |
-
head_mask=head_mask,
|
| 347 |
-
inputs_embeds=inputs_embeds,
|
| 348 |
-
use_cache=use_cache,
|
| 349 |
-
output_attentions=output_attentions,
|
| 350 |
-
output_hidden_states=output_hidden_states,
|
| 351 |
-
return_dict=return_dict,
|
| 352 |
-
)
|
| 353 |
-
hidden_states = transformer_outputs[0]
|
| 354 |
-
lm_logits = self.lm_head(hidden_states)
|
| 355 |
-
loss = None
|
| 356 |
-
if labels is not None:
|
| 357 |
-
shift_logits = lm_logits[..., :-1, :].contiguous()
|
| 358 |
-
shift_labels = labels[..., 1:].contiguous()
|
| 359 |
-
(batch_size, seq_length, vocab_size) = shift_logits.shape
|
| 360 |
-
loss_fct = CrossEntropyLoss()
|
| 361 |
-
loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
|
| 362 |
-
if not return_dict:
|
| 363 |
-
output = (lm_logits,) + transformer_outputs[1:]
|
| 364 |
-
return (loss,) + output if loss is not None else output
|
| 365 |
-
return CausalLMOutputWithCrossAttentions(
|
| 366 |
-
loss=loss,
|
| 367 |
-
logits=lm_logits,
|
| 368 |
-
past_key_values=transformer_outputs.past_key_values,
|
| 369 |
-
hidden_states=transformer_outputs.hidden_states,
|
| 370 |
-
attentions=transformer_outputs.attentions,
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
def prepare_inputs_for_generation(
|
| 374 |
-
self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
|
| 375 |
-
) -> dict:
|
| 376 |
-
if past:
|
| 377 |
-
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 378 |
-
bidirectional_mask = None
|
| 379 |
-
if past[0][0].shape[0] == input_ids.shape[0]:
|
| 380 |
-
past = self._convert_to_bloom_cache(past)
|
| 381 |
-
else:
|
| 382 |
-
bidirectional_mask = torch.ones_like(input_ids)
|
| 383 |
-
return {"input_ids": input_ids, "past_key_values": past, "use_cache": True, "attention_mask": attention_mask, "bidirectional_mask": bidirectional_mask}
|
| 384 |
-
|
| 385 |
-
setattr(model, "forward", MethodType(forward, model))
|
| 386 |
-
setattr(model, "prepare_inputs_for_generation", MethodType(prepare_inputs_for_generation, model))
|
| 387 |
-
setattr(model, "_prefix_lm_converted", True)
|
| 388 |
-
return model
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
|
| 392 |
-
"""Converts an OPT Causal LM to a Prefix LM.
|
| 393 |
-
|
| 394 |
-
Supported HuggingFace model classes:
|
| 395 |
-
- `OPTForCausalLM`
|
| 396 |
-
|
| 397 |
-
See `convert_hf_causal_lm_to_prefix_lm` for more details.
|
| 398 |
-
"""
|
| 399 |
-
if hasattr(model, "_prefix_lm_converted"):
|
| 400 |
-
return model
|
| 401 |
-
assert isinstance(model, OPTForCausalLM)
|
| 402 |
-
assert model.config.add_cross_attention == False, "Only supports OPT decoder-only models"
|
| 403 |
-
setattr(model, "_original_forward", getattr(model, "forward"))
|
| 404 |
-
setattr(model, "_original_generate", getattr(model, "generate"))
|
| 405 |
-
model.model.decoder.bidirectional_mask = None
|
| 406 |
-
|
| 407 |
-
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
|
| 408 |
-
combined_attention_mask = None
|
| 409 |
-
if input_shape[-1] > 1:
|
| 410 |
-
if self.bidirectional_mask == "g":
|
| 411 |
-
(bsz, src_length) = input_shape
|
| 412 |
-
combined_attention_mask = torch.zeros(
|
| 413 |
-
(bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
| 414 |
-
)
|
| 415 |
-
else:
|
| 416 |
-
combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(
|
| 417 |
-
inputs_embeds.device
|
| 418 |
-
)
|
| 419 |
-
if self.bidirectional_mask is not None:
|
| 420 |
-
assert attention_mask.shape == self.bidirectional_mask.shape
|
| 421 |
-
expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
|
| 422 |
-
inputs_embeds.device
|
| 423 |
-
)
|
| 424 |
-
combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
|
| 425 |
-
if attention_mask is not None:
|
| 426 |
-
expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
|
| 427 |
-
combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
|
| 428 |
-
return combined_attention_mask
|
| 429 |
-
|
| 430 |
-
setattr(model.model.decoder, "_prepare_decoder_attention_mask", MethodType(_prepare_decoder_attention_mask, model.model.decoder))
|
| 431 |
-
|
| 432 |
-
def forward(
|
| 433 |
-
self: OPTForCausalLM,
|
| 434 |
-
input_ids: Optional[torch.LongTensor] = None,
|
| 435 |
-
attention_mask: Optional[torch.Tensor] = None,
|
| 436 |
-
bidirectional_mask: Optional[torch.ByteTensor] = None,
|
| 437 |
-
head_mask: Optional[torch.Tensor] = None,
|
| 438 |
-
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 439 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 440 |
-
labels: Optional[torch.LongTensor] = None,
|
| 441 |
-
use_cache: Optional[bool] = None,
|
| 442 |
-
output_attentions: Optional[bool] = None,
|
| 443 |
-
output_hidden_states: Optional[bool] = None,
|
| 444 |
-
return_dict: Optional[bool] = None,
|
| 445 |
-
):
|
| 446 |
-
def call_og_forward():
|
| 447 |
-
return self._original_forward(
|
| 448 |
-
input_ids=input_ids,
|
| 449 |
-
attention_mask=attention_mask,
|
| 450 |
-
head_mask=head_mask,
|
| 451 |
-
past_key_values=past_key_values,
|
| 452 |
-
inputs_embeds=inputs_embeds,
|
| 453 |
-
labels=labels,
|
| 454 |
-
use_cache=use_cache,
|
| 455 |
-
output_attentions=output_attentions,
|
| 456 |
-
output_hidden_states=output_hidden_states,
|
| 457 |
-
return_dict=return_dict,
|
| 458 |
-
)
|
| 459 |
-
|
| 460 |
-
if bidirectional_mask is None:
|
| 461 |
-
return call_og_forward()
|
| 462 |
-
self.model.decoder.bidirectional_mask = bidirectional_mask
|
| 463 |
-
try:
|
| 464 |
-
outputs = call_og_forward()
|
| 465 |
-
except:
|
| 466 |
-
self.model.decoder.bidirectional_mask = None
|
| 467 |
-
raise
|
| 468 |
-
self.model.decoder.bidirectional_mask = None
|
| 469 |
-
return outputs
|
| 470 |
-
|
| 471 |
-
def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
|
| 472 |
-
"""Wraps original generate to enable PrefixLM-style attention."""
|
| 473 |
-
self.model.decoder.bidirectional_mask = "g"
|
| 474 |
-
try:
|
| 475 |
-
output = self._original_generate(*args, **kwargs)
|
| 476 |
-
except:
|
| 477 |
-
self.model.decoder.bidirectional_mask = None
|
| 478 |
-
raise
|
| 479 |
-
self.model.decoder.bidirectional_mask = None
|
| 480 |
-
return output
|
| 481 |
-
|
| 482 |
-
setattr(model, "forward", MethodType(forward, model))
|
| 483 |
-
setattr(model, "generate", MethodType(generate, model))
|
| 484 |
-
setattr(model, "_prefix_lm_converted", True)
|
| 485 |
-
return model
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
|
| 489 |
-
CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
|
| 493 |
-
"""Converts a HuggingFace Causal LM to a Prefix LM.
|
| 494 |
-
|
| 495 |
-
Supported HuggingFace model classes:
|
| 496 |
-
- `GPT2LMHeadModel`
|
| 497 |
-
- `GPTNeoForCausalLM`
|
| 498 |
-
- `GPTNeoXForCausalLM`
|
| 499 |
-
- `GPTJForCausalLM`
|
| 500 |
-
- `BloomForCausalLM`
|
| 501 |
-
- `OPTForCausalLM`
|
| 502 |
-
|
| 503 |
-
Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
|
| 504 |
-
`generate` method and/or select underlying methods depending on the model class.
|
| 505 |
-
|
| 506 |
-
These changes preserve the model API, but add a new input to `forward`: "bidirectional_mask".
|
| 507 |
-
|
| 508 |
-
Notes on training:
|
| 509 |
-
To actually train the converted model as a Prefix LM, training batches will need to indicate
|
| 510 |
-
the prefix/target structure by including `bidirectional_mask` as part of the batch inputs.
|
| 511 |
-
|
| 512 |
-
**This is not a standard input and requires custom layers either within or after your dataloader.**
|
| 513 |
-
|
| 514 |
-
In addition to adding `bidirectional_mask` to the batch, this custom code should modify `labels`
|
| 515 |
-
such that `batch['labels'][batch['bidirectional_mask'] == 1] == -100`.
|
| 516 |
-
That is, the prefix portion of the sequence should not generate any loss. Loss should only be
|
| 517 |
-
generated by the target portion of the sequence.
|
| 518 |
-
|
| 519 |
-
Notes on `GPTNeoForCausalLM`:
|
| 520 |
-
To simplify the implementation, "global" and "local" attention layers are handled differently.
|
| 521 |
-
For "global" layers, we handle conversion as described above. For "local" layers, which use a
|
| 522 |
-
causal attention mask within a restricted local window, we do not alter the masking.
|
| 523 |
-
|
| 524 |
-
Notes on `forward` method conversion:
|
| 525 |
-
After conversion, the `forward` method will handle a new input, `bidirectional_mask`,
|
| 526 |
-
which should be a [batch_size, seq_length] byte tensor, where 1 indicates token positions
|
| 527 |
-
belonging to the prefix (prefix tokens can attend to one another bidirectionally), and
|
| 528 |
-
0 indicates token positions belonging to the target.
|
| 529 |
-
|
| 530 |
-
The new `forward` method will incorporate `bidirectional_mask` (if supplied) into the existing
|
| 531 |
-
causal mask, call the original `forward` method, and (if the causal mask is a buffer) reset
|
| 532 |
-
the causal masks before returning the result.
|
| 533 |
-
|
| 534 |
-
Notes on `generate` method conversion:
|
| 535 |
-
After conversion, the `generate` method will have the same signature but will internally
|
| 536 |
-
convert all causal masks to be purely bidirectional, call the original `generate` method, and
|
| 537 |
-
(where appropriate) reset the causal masks before returning the result.
|
| 538 |
-
|
| 539 |
-
This works thanks to the logic of the HuggingFace `generate` API, which first encodes the token
|
| 540 |
-
"prompt" passed to `generate` (which is treated as the prefix) and then sequentially generates
|
| 541 |
-
each new token. Encodings are cached as generation happens, so all prefix tokens can attend to one
|
| 542 |
-
another (as expected in a Prefix LM) and generated tokens can only attend to prefix tokens and
|
| 543 |
-
previously-generated tokens (also as expected in a Prefix LM).
|
| 544 |
-
|
| 545 |
-
To preserve the API, the original methods are renamed to `_original_forward` and
|
| 546 |
-
`_original_generate`, and replaced with new `forward` and `generate` methods that wrap
|
| 547 |
-
them, respectively. Although implementation details vary by model class.
|
| 548 |
-
"""
|
| 549 |
-
if isinstance(model, _SUPPORTED_GPT_MODELS):
|
| 550 |
-
return _convert_gpt_causal_lm_to_prefix_lm(model)
|
| 551 |
-
elif isinstance(model, BloomForCausalLM):
|
| 552 |
-
return _convert_bloom_causal_lm_to_prefix_lm(model)
|
| 553 |
-
elif isinstance(model, OPTForCausalLM):
|
| 554 |
-
return _convert_opt_causal_lm_to_prefix_lm(model)
|
| 555 |
-
else:
|
| 556 |
-
raise TypeError(f"Cannot convert model to Prefix LM. " + f"Model does not belong to set of supported HF models:" + f"\n{_SUPPORTED_HF_MODELS}")
|
| 557 |
-
|
| 558 |
-
|
| 559 |
-
def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
|
| 560 |
-
"""Attempts to add bidirectional_mask to batch if missing.
|
| 561 |
-
|
| 562 |
-
Raises:
|
| 563 |
-
KeyError if bidirectional_mask is missing and can't be inferred
|
| 564 |
-
"""
|
| 565 |
-
if "bidirectional_mask" not in batch:
|
| 566 |
-
if batch.get("mode", None) == "icl_task":
|
| 567 |
-
batch["bidirectional_mask"] = batch["attention_mask"].clone()
|
| 568 |
-
for i, continuation_indices in enumerate(batch["continuation_indices"]):
|
| 569 |
-
batch["bidirectional_mask"][i, continuation_indices] = 0
|
| 570 |
-
elif "labels" in batch and "attention_mask" in batch:
|
| 571 |
-
batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(
|
| 572 |
-
batch["attention_mask"]
|
| 573 |
-
)
|
| 574 |
-
else:
|
| 575 |
-
raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/meta_init_context.py
DELETED
|
@@ -1,98 +0,0 @@
|
|
| 1 |
-
from contextlib import contextmanager
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
@contextmanager
|
| 7 |
-
def init_empty_weights(include_buffers: bool = False):
|
| 8 |
-
"""Meta initialization context manager.
|
| 9 |
-
|
| 10 |
-
A context manager under which models are initialized with all parameters
|
| 11 |
-
on the meta device, therefore creating an empty model. Useful when just
|
| 12 |
-
initializing the model would blow the available RAM.
|
| 13 |
-
|
| 14 |
-
Args:
|
| 15 |
-
include_buffers (`bool`, *optional*, defaults to `False`): Whether or
|
| 16 |
-
not to also put all buffers on the meta device while initializing.
|
| 17 |
-
|
| 18 |
-
Example:
|
| 19 |
-
```python
|
| 20 |
-
import torch.nn as nn
|
| 21 |
-
|
| 22 |
-
# Initialize a model with 100 billions parameters in no time and without using any RAM.
|
| 23 |
-
with init_empty_weights():
|
| 24 |
-
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
| 25 |
-
```
|
| 26 |
-
|
| 27 |
-
<Tip warning={true}>
|
| 28 |
-
|
| 29 |
-
Any model created under this context manager has no weights. As such you can't do something like
|
| 30 |
-
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
| 31 |
-
|
| 32 |
-
</Tip>
|
| 33 |
-
"""
|
| 34 |
-
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
|
| 35 |
-
yield f
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
@contextmanager
|
| 39 |
-
def init_on_device(device: torch.device, include_buffers: bool = False):
|
| 40 |
-
"""Device initialization context manager.
|
| 41 |
-
|
| 42 |
-
A context manager under which models are initialized with all parameters
|
| 43 |
-
on the specified device.
|
| 44 |
-
|
| 45 |
-
Args:
|
| 46 |
-
device (`torch.device`): Device to initialize all parameters on.
|
| 47 |
-
include_buffers (`bool`, *optional*, defaults to `False`): Whether or
|
| 48 |
-
not to also put all buffers on the meta device while initializing.
|
| 49 |
-
|
| 50 |
-
Example:
|
| 51 |
-
```python
|
| 52 |
-
import torch.nn as nn
|
| 53 |
-
|
| 54 |
-
with init_on_device(device=torch.device("cuda")):
|
| 55 |
-
tst = nn.Liner(100, 100) # on `cuda` device
|
| 56 |
-
```
|
| 57 |
-
"""
|
| 58 |
-
old_register_parameter = nn.Module.register_parameter
|
| 59 |
-
if include_buffers:
|
| 60 |
-
old_register_buffer = nn.Module.register_buffer
|
| 61 |
-
|
| 62 |
-
def register_empty_parameter(module, name, param):
|
| 63 |
-
old_register_parameter(module, name, param)
|
| 64 |
-
if param is not None:
|
| 65 |
-
param_cls = type(module._parameters[name])
|
| 66 |
-
kwargs = module._parameters[name].__dict__
|
| 67 |
-
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
| 68 |
-
|
| 69 |
-
def register_empty_buffer(module, name, buffer):
|
| 70 |
-
old_register_buffer(module, name, buffer)
|
| 71 |
-
if buffer is not None:
|
| 72 |
-
module._buffers[name] = module._buffers[name].to(device)
|
| 73 |
-
|
| 74 |
-
if include_buffers:
|
| 75 |
-
tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ["empty", "zeros", "ones", "full"]}
|
| 76 |
-
else:
|
| 77 |
-
tensor_constructors_to_patch = {}
|
| 78 |
-
|
| 79 |
-
def patch_tensor_constructor(fn):
|
| 80 |
-
def wrapper(*args, **kwargs):
|
| 81 |
-
kwargs["device"] = device
|
| 82 |
-
return fn(*args, **kwargs)
|
| 83 |
-
|
| 84 |
-
return wrapper
|
| 85 |
-
|
| 86 |
-
try:
|
| 87 |
-
nn.Module.register_parameter = register_empty_parameter
|
| 88 |
-
if include_buffers:
|
| 89 |
-
nn.Module.register_buffer = register_empty_buffer
|
| 90 |
-
for torch_function_name in tensor_constructors_to_patch.keys():
|
| 91 |
-
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
| 92 |
-
yield
|
| 93 |
-
finally:
|
| 94 |
-
nn.Module.register_parameter = old_register_parameter
|
| 95 |
-
if include_buffers:
|
| 96 |
-
nn.Module.register_buffer = old_register_buffer
|
| 97 |
-
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
| 98 |
-
setattr(torch, torch_function_name, old_torch_function)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/modeling_mpt.py
DELETED
|
@@ -1,496 +0,0 @@
|
|
| 1 |
-
"""A simple, flexible implementation of a GPT model.
|
| 2 |
-
|
| 3 |
-
Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
| 4 |
-
"""
|
| 5 |
-
import math
|
| 6 |
-
import warnings
|
| 7 |
-
from typing import List, Optional, Tuple, Union
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn as nn
|
| 11 |
-
import torch.nn.functional as F
|
| 12 |
-
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
| 13 |
-
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
| 14 |
-
|
| 15 |
-
from .attention import attn_bias_shape, build_attn_bias
|
| 16 |
-
from .blocks import MPTBlock
|
| 17 |
-
from .configuration_mpt import MPTConfig
|
| 18 |
-
from .custom_embedding import SharedEmbedding
|
| 19 |
-
from .norm import NORM_CLASS_REGISTRY
|
| 20 |
-
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
|
| 21 |
-
|
| 22 |
-
import torch.distributed as dist
|
| 23 |
-
|
| 24 |
-
try:
|
| 25 |
-
from .flash_attn_triton import flash_attn_func
|
| 26 |
-
except:
|
| 27 |
-
pass
|
| 28 |
-
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class MPTPreTrainedModel(PreTrainedModel):
|
| 32 |
-
config_class = MPTConfig
|
| 33 |
-
base_model_prefix = "model"
|
| 34 |
-
_no_split_modules = ["MPTBlock"]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class MPTModel(MPTPreTrainedModel):
|
| 38 |
-
def __init__(self, config: MPTConfig):
|
| 39 |
-
config._validate_config()
|
| 40 |
-
super().__init__(config)
|
| 41 |
-
self.attn_impl = config.attn_config["attn_impl"]
|
| 42 |
-
self.prefix_lm = config.attn_config["prefix_lm"]
|
| 43 |
-
self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
|
| 44 |
-
self.alibi = config.attn_config["alibi"]
|
| 45 |
-
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
|
| 46 |
-
if config.init_device == "mixed":
|
| 47 |
-
if dist.get_local_rank() == 0:
|
| 48 |
-
config.init_device = "cpu"
|
| 49 |
-
else:
|
| 50 |
-
config.init_device = "meta"
|
| 51 |
-
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
|
| 52 |
-
norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
|
| 53 |
-
raise NotImplementedError(f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).")
|
| 54 |
-
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
|
| 55 |
-
self.embedding_fraction = config.embedding_fraction
|
| 56 |
-
self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
|
| 57 |
-
if not self.alibi:
|
| 58 |
-
self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
|
| 59 |
-
self.emb_drop = nn.Dropout(config.emb_pdrop)
|
| 60 |
-
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
|
| 61 |
-
self.norm_f = norm_class(config.d_model, device=config.init_device)
|
| 62 |
-
if config.init_device != "meta":
|
| 63 |
-
print(
|
| 64 |
-
f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
|
| 65 |
-
)
|
| 66 |
-
self.apply(self.param_init_fn)
|
| 67 |
-
self.is_causal = not self.prefix_lm
|
| 68 |
-
self._attn_bias_initialized = False
|
| 69 |
-
self.attn_bias = None
|
| 70 |
-
self.attn_bias_shape = attn_bias_shape(
|
| 71 |
-
self.attn_impl,
|
| 72 |
-
config.n_heads,
|
| 73 |
-
config.max_seq_len,
|
| 74 |
-
self.alibi,
|
| 75 |
-
prefix_lm=self.prefix_lm,
|
| 76 |
-
causal=self.is_causal,
|
| 77 |
-
use_sequence_id=self.attn_uses_sequence_id,
|
| 78 |
-
)
|
| 79 |
-
if config.no_bias:
|
| 80 |
-
for module in self.modules():
|
| 81 |
-
if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
|
| 82 |
-
if config.verbose:
|
| 83 |
-
warnings.warn(f"Removing bias ({module.bias}) from {module}.")
|
| 84 |
-
module.register_parameter("bias", None)
|
| 85 |
-
if config.verbose and config.verbose > 2:
|
| 86 |
-
print(self)
|
| 87 |
-
if "verbose" not in self.config.init_config:
|
| 88 |
-
self.config.init_config["verbose"] = self.config.verbose
|
| 89 |
-
if self.config.init_config["verbose"] > 1:
|
| 90 |
-
init_fn_name = self.config.init_config["name"]
|
| 91 |
-
warnings.warn(f"Using {init_fn_name} initialization.")
|
| 92 |
-
|
| 93 |
-
def get_input_embeddings(self):
|
| 94 |
-
return self.wte
|
| 95 |
-
|
| 96 |
-
def set_input_embeddings(self, value):
|
| 97 |
-
self.wte = value
|
| 98 |
-
|
| 99 |
-
@torch.no_grad()
|
| 100 |
-
def _attn_bias(
|
| 101 |
-
self,
|
| 102 |
-
device,
|
| 103 |
-
dtype,
|
| 104 |
-
attention_mask: Optional[torch.ByteTensor] = None,
|
| 105 |
-
prefix_mask: Optional[torch.ByteTensor] = None,
|
| 106 |
-
sequence_id: Optional[torch.LongTensor] = None,
|
| 107 |
-
):
|
| 108 |
-
if not self._attn_bias_initialized:
|
| 109 |
-
if self.attn_bias_shape:
|
| 110 |
-
self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
|
| 111 |
-
self.attn_bias = build_attn_bias(
|
| 112 |
-
self.attn_impl,
|
| 113 |
-
self.attn_bias,
|
| 114 |
-
self.config.n_heads,
|
| 115 |
-
self.config.max_seq_len,
|
| 116 |
-
causal=self.is_causal,
|
| 117 |
-
alibi=self.alibi,
|
| 118 |
-
alibi_bias_max=self.alibi_bias_max,
|
| 119 |
-
)
|
| 120 |
-
self._attn_bias_initialized = True
|
| 121 |
-
if self.attn_impl == "flash":
|
| 122 |
-
return (self.attn_bias, attention_mask)
|
| 123 |
-
if self.attn_bias is not None:
|
| 124 |
-
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
|
| 125 |
-
attn_bias = self.attn_bias
|
| 126 |
-
if self.prefix_lm:
|
| 127 |
-
assert isinstance(attn_bias, torch.Tensor)
|
| 128 |
-
assert isinstance(prefix_mask, torch.Tensor)
|
| 129 |
-
attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
|
| 130 |
-
if self.attn_uses_sequence_id and sequence_id is not None:
|
| 131 |
-
assert isinstance(attn_bias, torch.Tensor)
|
| 132 |
-
attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
|
| 133 |
-
if attention_mask is not None:
|
| 134 |
-
s_k = attention_mask.shape[-1]
|
| 135 |
-
if attn_bias is None:
|
| 136 |
-
attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
|
| 137 |
-
else:
|
| 138 |
-
_s_k = max(0, attn_bias.size(-1) - s_k)
|
| 139 |
-
attn_bias = attn_bias[:, :, :, _s_k:]
|
| 140 |
-
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
|
| 141 |
-
raise ValueError(f"attention_mask shape={attention_mask.shape} " + f"and prefix_mask shape={prefix_mask.shape} are not equal.")
|
| 142 |
-
min_val = torch.finfo(attn_bias.dtype).min
|
| 143 |
-
attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
|
| 144 |
-
return (attn_bias, None)
|
| 145 |
-
|
| 146 |
-
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
|
| 147 |
-
(s_k, s_q) = attn_bias.shape[-2:]
|
| 148 |
-
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
|
| 149 |
-
raise ValueError(
|
| 150 |
-
"attn_bias does not match the expected shape. "
|
| 151 |
-
+ f"The last two dimensions should both be {self.config.max_length} "
|
| 152 |
-
+ f"but are {s_k} and {s_q}."
|
| 153 |
-
)
|
| 154 |
-
seq_len = prefix_mask.shape[-1]
|
| 155 |
-
if seq_len > self.config.max_seq_len:
|
| 156 |
-
raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
|
| 157 |
-
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
| 158 |
-
causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
|
| 159 |
-
prefix = prefix_mask.view(-1, 1, 1, seq_len)
|
| 160 |
-
cannot_attend = ~torch.logical_or(causal, prefix.bool())
|
| 161 |
-
min_val = torch.finfo(attn_bias.dtype).min
|
| 162 |
-
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
| 163 |
-
return attn_bias
|
| 164 |
-
|
| 165 |
-
def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
|
| 166 |
-
seq_len = sequence_id.shape[-1]
|
| 167 |
-
if seq_len > self.config.max_seq_len:
|
| 168 |
-
raise ValueError(f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
|
| 169 |
-
attn_bias = attn_bias[..., :seq_len, :seq_len]
|
| 170 |
-
cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
|
| 171 |
-
min_val = torch.finfo(attn_bias.dtype).min
|
| 172 |
-
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
|
| 173 |
-
return attn_bias
|
| 174 |
-
|
| 175 |
-
def forward(
|
| 176 |
-
self,
|
| 177 |
-
input_ids: torch.LongTensor,
|
| 178 |
-
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
| 179 |
-
attention_mask: Optional[torch.ByteTensor] = None,
|
| 180 |
-
prefix_mask: Optional[torch.ByteTensor] = None,
|
| 181 |
-
sequence_id: Optional[torch.LongTensor] = None,
|
| 182 |
-
return_dict: Optional[bool] = None,
|
| 183 |
-
output_attentions: Optional[bool] = None,
|
| 184 |
-
output_hidden_states: Optional[bool] = None,
|
| 185 |
-
use_cache: Optional[bool] = None,
|
| 186 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
| 187 |
-
):
|
| 188 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 189 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 190 |
-
|
| 191 |
-
if attention_mask is not None:
|
| 192 |
-
attention_mask = attention_mask.bool()
|
| 193 |
-
|
| 194 |
-
if prefix_mask is not None:
|
| 195 |
-
prefix_mask = prefix_mask.bool()
|
| 196 |
-
|
| 197 |
-
# These args are passed in by keyword in huggingface's generate function
|
| 198 |
-
# https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
|
| 199 |
-
# but have not yet been fully implemented in MPTModel
|
| 200 |
-
if not return_dict:
|
| 201 |
-
raise NotImplementedError("return_dict False is not implemented yet for MPT")
|
| 202 |
-
if output_attentions:
|
| 203 |
-
if self.attn_impl != "torch":
|
| 204 |
-
raise NotImplementedError("output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.")
|
| 205 |
-
|
| 206 |
-
if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
|
| 207 |
-
raise NotImplementedError("MPT does not support training with left padding.")
|
| 208 |
-
|
| 209 |
-
if self.prefix_lm and prefix_mask is None:
|
| 210 |
-
raise ValueError("prefix_mask is a required argument when MPT is configured with prefix_lm=True.")
|
| 211 |
-
|
| 212 |
-
# Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
|
| 213 |
-
if inputs_embeds is not None:
|
| 214 |
-
raise NotImplementedError("inputs_embeds is not implemented for MPT.")
|
| 215 |
-
|
| 216 |
-
if self.training:
|
| 217 |
-
if self.attn_uses_sequence_id and sequence_id is None:
|
| 218 |
-
raise ValueError(
|
| 219 |
-
"sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode."
|
| 220 |
-
)
|
| 221 |
-
elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
|
| 222 |
-
warnings.warn(
|
| 223 |
-
"MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
|
| 224 |
-
+ "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
S = input_ids.size(1)
|
| 228 |
-
|
| 229 |
-
assert S <= self.config.max_seq_len, f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
|
| 230 |
-
|
| 231 |
-
tok_emb = self.wte(input_ids) # type: ignore
|
| 232 |
-
if self.alibi:
|
| 233 |
-
x = tok_emb
|
| 234 |
-
else:
|
| 235 |
-
past_position = 0
|
| 236 |
-
if past_key_values is not None:
|
| 237 |
-
if len(past_key_values) != self.config.n_layers:
|
| 238 |
-
raise ValueError(
|
| 239 |
-
f"past_key_values must provide a past_key_value for each attention "
|
| 240 |
-
+ f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
|
| 241 |
-
)
|
| 242 |
-
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
|
| 243 |
-
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
|
| 244 |
-
# Here we shift position embedding using the `seq` dim of the past key
|
| 245 |
-
past_position = past_key_values[0][0].size(1)
|
| 246 |
-
if self.attn_impl == "torch":
|
| 247 |
-
past_position = past_key_values[0][0].size(3)
|
| 248 |
-
|
| 249 |
-
if S + past_position > self.config.max_seq_len:
|
| 250 |
-
raise ValueError(
|
| 251 |
-
f"Cannot forward input with past sequence length {past_position} and current sequence length "
|
| 252 |
-
f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
|
| 253 |
-
)
|
| 254 |
-
pos = torch.arange(
|
| 255 |
-
past_position,
|
| 256 |
-
S + past_position,
|
| 257 |
-
dtype=torch.long,
|
| 258 |
-
device=input_ids.device,
|
| 259 |
-
).unsqueeze(0)
|
| 260 |
-
if attention_mask is not None:
|
| 261 |
-
# adjust the position indices to account for padding tokens
|
| 262 |
-
pos = torch.clamp(
|
| 263 |
-
pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:],
|
| 264 |
-
min=0,
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
pos_emb = self.wpe(pos) # type: ignore
|
| 268 |
-
x = tok_emb + pos_emb
|
| 269 |
-
|
| 270 |
-
if self.embedding_fraction == 1:
|
| 271 |
-
x = self.emb_drop(x) # type: ignore
|
| 272 |
-
else:
|
| 273 |
-
# this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
|
| 274 |
-
x_shrunk = (x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction))
|
| 275 |
-
assert isinstance(self.emb_drop, nn.Module) # pyright
|
| 276 |
-
x = self.emb_drop(x_shrunk)
|
| 277 |
-
|
| 278 |
-
attn_bias, attention_mask = self._attn_bias(
|
| 279 |
-
device=x.device,
|
| 280 |
-
dtype=torch.float32,
|
| 281 |
-
attention_mask=attention_mask,
|
| 282 |
-
prefix_mask=prefix_mask,
|
| 283 |
-
sequence_id=sequence_id,
|
| 284 |
-
)
|
| 285 |
-
|
| 286 |
-
# initialize the past key values cache if it should be used
|
| 287 |
-
if use_cache and past_key_values is None:
|
| 288 |
-
past_key_values = [() for _ in range(self.config.n_layers)] # type: ignore
|
| 289 |
-
|
| 290 |
-
all_hidden_states = () if output_hidden_states else None
|
| 291 |
-
all_self_attns = () if output_attentions else None
|
| 292 |
-
for b_idx, block in enumerate(self.blocks): # type: ignore
|
| 293 |
-
if output_hidden_states:
|
| 294 |
-
assert all_hidden_states is not None # pyright
|
| 295 |
-
all_hidden_states = all_hidden_states + (x,)
|
| 296 |
-
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
| 297 |
-
x, attn_weights, past_key_value = block(
|
| 298 |
-
x,
|
| 299 |
-
past_key_value=past_key_value,
|
| 300 |
-
attn_bias=attn_bias,
|
| 301 |
-
attention_mask=attention_mask,
|
| 302 |
-
is_causal=self.is_causal,
|
| 303 |
-
)
|
| 304 |
-
if past_key_values is not None:
|
| 305 |
-
past_key_values[b_idx] = past_key_value
|
| 306 |
-
|
| 307 |
-
if output_attentions:
|
| 308 |
-
assert all_self_attns is not None # pyright
|
| 309 |
-
all_self_attns = all_self_attns + (attn_weights,)
|
| 310 |
-
|
| 311 |
-
x = self.norm_f(x) # type: ignore
|
| 312 |
-
|
| 313 |
-
# add hidden states from the last decoder layer
|
| 314 |
-
if output_hidden_states:
|
| 315 |
-
assert all_hidden_states is not None # pyright
|
| 316 |
-
all_hidden_states = all_hidden_states + (x,)
|
| 317 |
-
|
| 318 |
-
return BaseModelOutputWithPast(
|
| 319 |
-
last_hidden_state=x,
|
| 320 |
-
past_key_values=past_key_values,
|
| 321 |
-
hidden_states=all_hidden_states,
|
| 322 |
-
attentions=all_self_attns,
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
# Param Initialization, needed for device='meta' fast initialization
|
| 326 |
-
def param_init_fn(self, module):
|
| 327 |
-
init_fn_name = self.config.init_config["name"]
|
| 328 |
-
MODEL_INIT_REGISTRY[init_fn_name](
|
| 329 |
-
module=module,
|
| 330 |
-
n_layers=self.config.n_layers,
|
| 331 |
-
d_model=self.config.d_model,
|
| 332 |
-
**self.config.init_config,
|
| 333 |
-
)
|
| 334 |
-
|
| 335 |
-
def fsdp_wrap_fn(self, module):
|
| 336 |
-
return isinstance(module, MPTBlock)
|
| 337 |
-
|
| 338 |
-
def activation_checkpointing_fn(self, module):
|
| 339 |
-
return isinstance(module, MPTBlock)
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
class MPTForCausalLM(MPTPreTrainedModel):
|
| 343 |
-
def __init__(self, config: MPTConfig):
|
| 344 |
-
super().__init__(config)
|
| 345 |
-
if not config.tie_word_embeddings:
|
| 346 |
-
raise ValueError("MPTForCausalLM only supports tied word embeddings")
|
| 347 |
-
self.transformer = MPTModel(config)
|
| 348 |
-
for child in self.transformer.children():
|
| 349 |
-
if isinstance(child, torch.nn.ModuleList):
|
| 350 |
-
continue
|
| 351 |
-
if isinstance(child, torch.nn.Module):
|
| 352 |
-
child._fsdp_wrap = True
|
| 353 |
-
self.logit_scale = None
|
| 354 |
-
if config.logit_scale is not None:
|
| 355 |
-
logit_scale = config.logit_scale
|
| 356 |
-
if isinstance(logit_scale, str):
|
| 357 |
-
if logit_scale == "inv_sqrt_d_model":
|
| 358 |
-
logit_scale = 1 / math.sqrt(config.d_model)
|
| 359 |
-
else:
|
| 360 |
-
raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
|
| 361 |
-
self.logit_scale = logit_scale
|
| 362 |
-
|
| 363 |
-
def get_input_embeddings(self):
|
| 364 |
-
return self.transformer.wte
|
| 365 |
-
|
| 366 |
-
def set_input_embeddings(self, value):
|
| 367 |
-
# self.transformer.wte = value
|
| 368 |
-
peudo_wte = SharedEmbedding(value.weight.shape[0], value.weight.shape[1], device=self.transformer.wte.weight.device)
|
| 369 |
-
peudo_wte.weight = value.weight
|
| 370 |
-
self.transformer.wte = peudo_wte
|
| 371 |
-
|
| 372 |
-
def get_output_embeddings(self):
|
| 373 |
-
return self.transformer.wte
|
| 374 |
-
|
| 375 |
-
def set_output_embeddings(self, new_embeddings):
|
| 376 |
-
# self.transformer.wte = new_embeddings
|
| 377 |
-
peudo_wte = SharedEmbedding(new_embeddings.weight.shape[0], new_embeddings.weight.shape[1], device=self.transformer.wte.weight.device)
|
| 378 |
-
peudo_wte.weight = new_embeddings.weight
|
| 379 |
-
self.transformer.wte = peudo_wte
|
| 380 |
-
|
| 381 |
-
def set_decoder(self, decoder):
|
| 382 |
-
self.transformer = decoder
|
| 383 |
-
|
| 384 |
-
def get_decoder(self):
|
| 385 |
-
return self.transformer
|
| 386 |
-
|
| 387 |
-
def forward(
|
| 388 |
-
self,
|
| 389 |
-
input_ids: torch.LongTensor,
|
| 390 |
-
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
| 391 |
-
attention_mask: Optional[torch.ByteTensor] = None,
|
| 392 |
-
prefix_mask: Optional[torch.ByteTensor] = None,
|
| 393 |
-
sequence_id: Optional[torch.LongTensor] = None,
|
| 394 |
-
labels: Optional[torch.LongTensor] = None,
|
| 395 |
-
return_dict: Optional[bool] = None,
|
| 396 |
-
output_attentions: Optional[bool] = None,
|
| 397 |
-
output_hidden_states: Optional[bool] = None,
|
| 398 |
-
use_cache: Optional[bool] = None,
|
| 399 |
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 400 |
-
):
|
| 401 |
-
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 402 |
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 403 |
-
|
| 404 |
-
# if input_embeds is not none, raise a not implemented error
|
| 405 |
-
if inputs_embeds is not None:
|
| 406 |
-
raise NotImplementedError("inputs_embeds has to be None (for hf/peft support).")
|
| 407 |
-
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 408 |
-
outputs = self.transformer(
|
| 409 |
-
input_ids=input_ids,
|
| 410 |
-
past_key_values=past_key_values,
|
| 411 |
-
attention_mask=attention_mask,
|
| 412 |
-
prefix_mask=prefix_mask,
|
| 413 |
-
sequence_id=sequence_id,
|
| 414 |
-
return_dict=return_dict,
|
| 415 |
-
output_attentions=output_attentions,
|
| 416 |
-
output_hidden_states=output_hidden_states,
|
| 417 |
-
use_cache=use_cache,
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
# move outputs to same device as weights for token embedding
|
| 421 |
-
# needed to support HF `device_map`
|
| 422 |
-
logits = self.transformer.wte(
|
| 423 |
-
input=outputs.last_hidden_state.to(self.transformer.wte.weight.device),
|
| 424 |
-
unembed=True,
|
| 425 |
-
)
|
| 426 |
-
|
| 427 |
-
if self.logit_scale is not None:
|
| 428 |
-
if self.logit_scale == 0:
|
| 429 |
-
warnings.warn(f"Multiplying logits by {self.logit_scale=}. This will produce uniform (uninformative) outputs.")
|
| 430 |
-
logits *= self.logit_scale
|
| 431 |
-
|
| 432 |
-
loss = None
|
| 433 |
-
if labels is not None:
|
| 434 |
-
_labels = torch.roll(labels, shifts=-1)
|
| 435 |
-
_labels[:, -1] = -100
|
| 436 |
-
loss = F.cross_entropy(
|
| 437 |
-
logits.view(-1, logits.size(-1)),
|
| 438 |
-
_labels.to(logits.device).view(-1),
|
| 439 |
-
)
|
| 440 |
-
|
| 441 |
-
return CausalLMOutputWithPast(
|
| 442 |
-
loss=loss,
|
| 443 |
-
logits=logits,
|
| 444 |
-
past_key_values=outputs.past_key_values,
|
| 445 |
-
hidden_states=outputs.hidden_states,
|
| 446 |
-
attentions=outputs.attentions,
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
def param_init_fn(self, module):
|
| 450 |
-
init_fn_name = self.config.init_config["name"]
|
| 451 |
-
MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
|
| 452 |
-
|
| 453 |
-
def fsdp_wrap_fn(self, module):
|
| 454 |
-
return isinstance(module, MPTBlock)
|
| 455 |
-
|
| 456 |
-
def activation_checkpointing_fn(self, module):
|
| 457 |
-
return isinstance(module, MPTBlock)
|
| 458 |
-
|
| 459 |
-
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, **kwargs):
|
| 460 |
-
if inputs_embeds is not None:
|
| 461 |
-
raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
|
| 462 |
-
attention_mask = attention_mask.bool()
|
| 463 |
-
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
|
| 464 |
-
raise NotImplementedError("MPT does not support generation with right padding.")
|
| 465 |
-
if self.transformer.attn_uses_sequence_id and self.training:
|
| 466 |
-
sequence_id = torch.zeros_like(input_ids[:1])
|
| 467 |
-
else:
|
| 468 |
-
sequence_id = None
|
| 469 |
-
if past_key_values is not None:
|
| 470 |
-
input_ids = input_ids[:, -1].unsqueeze(-1)
|
| 471 |
-
if self.transformer.prefix_lm:
|
| 472 |
-
prefix_mask = torch.ones_like(attention_mask)
|
| 473 |
-
if kwargs.get("use_cache") == False:
|
| 474 |
-
raise NotImplementedError("MPT with prefix_lm=True does not support use_cache=False.")
|
| 475 |
-
else:
|
| 476 |
-
prefix_mask = None
|
| 477 |
-
return {
|
| 478 |
-
"input_ids": input_ids,
|
| 479 |
-
"attention_mask": attention_mask,
|
| 480 |
-
"prefix_mask": prefix_mask,
|
| 481 |
-
"sequence_id": sequence_id,
|
| 482 |
-
"past_key_values": past_key_values,
|
| 483 |
-
"use_cache": kwargs.get("use_cache", True),
|
| 484 |
-
}
|
| 485 |
-
|
| 486 |
-
@staticmethod
|
| 487 |
-
def _reorder_cache(past_key_values, beam_idx):
|
| 488 |
-
"""Used by HuggingFace generate when using beam search with kv-caching.
|
| 489 |
-
|
| 490 |
-
See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
|
| 491 |
-
for an example in transformers.
|
| 492 |
-
"""
|
| 493 |
-
reordered_past = []
|
| 494 |
-
for layer_past in past_key_values:
|
| 495 |
-
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
| 496 |
-
return reordered_past
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/norm.py
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
def _cast_if_autocast_enabled(tensor):
|
| 5 |
-
if torch.is_autocast_enabled():
|
| 6 |
-
if tensor.device.type == "cuda":
|
| 7 |
-
dtype = torch.get_autocast_gpu_dtype()
|
| 8 |
-
elif tensor.device.type == "cpu":
|
| 9 |
-
dtype = torch.get_autocast_cpu_dtype()
|
| 10 |
-
else:
|
| 11 |
-
raise NotImplementedError()
|
| 12 |
-
return tensor.to(dtype=dtype)
|
| 13 |
-
return tensor
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
class LPLayerNorm(torch.nn.LayerNorm):
|
| 17 |
-
def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
|
| 18 |
-
super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
| 19 |
-
|
| 20 |
-
def forward(self, x):
|
| 21 |
-
module_device = x.device
|
| 22 |
-
downcast_x = _cast_if_autocast_enabled(x)
|
| 23 |
-
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 24 |
-
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
|
| 25 |
-
with torch.autocast(enabled=False, device_type=module_device.type):
|
| 26 |
-
return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def rms_norm(x, weight=None, eps=1e-05):
|
| 30 |
-
output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
| 31 |
-
if weight is not None:
|
| 32 |
-
return output * weight
|
| 33 |
-
return output
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class RMSNorm(torch.nn.Module):
|
| 37 |
-
def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
|
| 38 |
-
super().__init__()
|
| 39 |
-
self.eps = eps
|
| 40 |
-
if weight:
|
| 41 |
-
self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
|
| 42 |
-
else:
|
| 43 |
-
self.register_parameter("weight", None)
|
| 44 |
-
|
| 45 |
-
def forward(self, x):
|
| 46 |
-
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class LPRMSNorm(RMSNorm):
|
| 50 |
-
def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
|
| 51 |
-
super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
|
| 52 |
-
|
| 53 |
-
def forward(self, x):
|
| 54 |
-
downcast_x = _cast_if_autocast_enabled(x)
|
| 55 |
-
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
|
| 56 |
-
with torch.autocast(enabled=False, device_type=x.device.type):
|
| 57 |
-
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
NORM_CLASS_REGISTRY = {"layernorm": torch.nn.LayerNorm, "low_precision_layernorm": LPLayerNorm, "rmsnorm": RMSNorm, "low_precision_rmsnorm": LPRMSNorm}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt/param_init_fns.py
DELETED
|
@@ -1,369 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
import warnings
|
| 3 |
-
from collections.abc import Sequence
|
| 4 |
-
from functools import partial
|
| 5 |
-
from typing import Optional, Tuple, Union
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn
|
| 8 |
-
from .norm import NORM_CLASS_REGISTRY
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
|
| 12 |
-
del kwargs
|
| 13 |
-
if verbose > 1:
|
| 14 |
-
warnings.warn(f"Initializing network using module's reset_parameters attribute")
|
| 15 |
-
if hasattr(module, "reset_parameters"):
|
| 16 |
-
module.reset_parameters()
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def fused_init_helper_(module: nn.Module, init_fn_):
|
| 20 |
-
_fused = getattr(module, "_fused", None)
|
| 21 |
-
if _fused is None:
|
| 22 |
-
raise RuntimeError(f"Internal logic error")
|
| 23 |
-
(dim, splits) = _fused
|
| 24 |
-
splits = (0, *splits, module.weight.size(dim))
|
| 25 |
-
for s, e in zip(splits[:-1], splits[1:]):
|
| 26 |
-
slice_indices = [slice(None)] * module.weight.ndim
|
| 27 |
-
slice_indices[dim] = slice(s, e)
|
| 28 |
-
init_fn_(module.weight[slice_indices])
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def generic_param_init_fn_(
|
| 32 |
-
module: nn.Module,
|
| 33 |
-
init_fn_,
|
| 34 |
-
n_layers: int,
|
| 35 |
-
d_model: Optional[int] = None,
|
| 36 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 37 |
-
emb_init_std: Optional[float] = None,
|
| 38 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 39 |
-
verbose: int = 0,
|
| 40 |
-
**kwargs,
|
| 41 |
-
):
|
| 42 |
-
del kwargs
|
| 43 |
-
if verbose > 1:
|
| 44 |
-
warnings.warn(f"If model has bias parameters they are initialized to 0.")
|
| 45 |
-
init_div_is_residual = init_div_is_residual
|
| 46 |
-
if init_div_is_residual is False:
|
| 47 |
-
div_is_residual = 1.0
|
| 48 |
-
elif init_div_is_residual is True:
|
| 49 |
-
div_is_residual = math.sqrt(2 * n_layers)
|
| 50 |
-
elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
|
| 51 |
-
div_is_residual = init_div_is_residual
|
| 52 |
-
elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
|
| 53 |
-
div_is_residual = float(init_div_is_residual)
|
| 54 |
-
else:
|
| 55 |
-
div_is_residual = 1.0
|
| 56 |
-
raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
|
| 57 |
-
if init_div_is_residual is not False:
|
| 58 |
-
if verbose > 1:
|
| 59 |
-
warnings.warn(
|
| 60 |
-
f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
|
| 61 |
-
+ f"Set `init_div_is_residual: false` in init config to disable this."
|
| 62 |
-
)
|
| 63 |
-
if isinstance(module, nn.Linear):
|
| 64 |
-
if hasattr(module, "_fused"):
|
| 65 |
-
fused_init_helper_(module, init_fn_)
|
| 66 |
-
else:
|
| 67 |
-
init_fn_(module.weight)
|
| 68 |
-
if module.bias is not None:
|
| 69 |
-
torch.nn.init.zeros_(module.bias)
|
| 70 |
-
if init_div_is_residual is not False and getattr(module, "_is_residual", False):
|
| 71 |
-
with torch.no_grad():
|
| 72 |
-
module.weight.div_(div_is_residual)
|
| 73 |
-
elif isinstance(module, nn.Embedding):
|
| 74 |
-
if emb_init_std is not None:
|
| 75 |
-
std = emb_init_std
|
| 76 |
-
if std == 0:
|
| 77 |
-
warnings.warn(f"Embedding layer initialized to 0.")
|
| 78 |
-
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
|
| 79 |
-
if verbose > 1:
|
| 80 |
-
warnings.warn(f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}.")
|
| 81 |
-
elif emb_init_uniform_lim is not None:
|
| 82 |
-
lim = emb_init_uniform_lim
|
| 83 |
-
if isinstance(lim, Sequence):
|
| 84 |
-
if len(lim) > 2:
|
| 85 |
-
raise ValueError(f"Uniform init requires a min and a max limit. User input: {lim}.")
|
| 86 |
-
if lim[0] == lim[1]:
|
| 87 |
-
warnings.warn(f"Embedding layer initialized to {lim[0]}.")
|
| 88 |
-
else:
|
| 89 |
-
if lim == 0:
|
| 90 |
-
warnings.warn(f"Embedding layer initialized to 0.")
|
| 91 |
-
lim = [-lim, lim]
|
| 92 |
-
(a, b) = lim
|
| 93 |
-
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
|
| 94 |
-
if verbose > 1:
|
| 95 |
-
warnings.warn(f"Embedding layer initialized using uniform distribution in range {lim}.")
|
| 96 |
-
else:
|
| 97 |
-
emb_init_fn_ = init_fn_
|
| 98 |
-
emb_init_fn_(module.weight)
|
| 99 |
-
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
|
| 100 |
-
if verbose > 1:
|
| 101 |
-
warnings.warn(f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0.")
|
| 102 |
-
if hasattr(module, "weight") and module.weight is not None:
|
| 103 |
-
torch.nn.init.ones_(module.weight)
|
| 104 |
-
if hasattr(module, "bias") and module.bias is not None:
|
| 105 |
-
torch.nn.init.zeros_(module.bias)
|
| 106 |
-
elif isinstance(module, nn.MultiheadAttention):
|
| 107 |
-
if module._qkv_same_embed_dim:
|
| 108 |
-
assert module.in_proj_weight is not None
|
| 109 |
-
assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
|
| 110 |
-
assert d_model is not None
|
| 111 |
-
_d = d_model
|
| 112 |
-
splits = (0, _d, 2 * _d, 3 * _d)
|
| 113 |
-
for s, e in zip(splits[:-1], splits[1:]):
|
| 114 |
-
init_fn_(module.in_proj_weight[s:e])
|
| 115 |
-
else:
|
| 116 |
-
assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
|
| 117 |
-
assert module.in_proj_weight is None
|
| 118 |
-
init_fn_(module.q_proj_weight)
|
| 119 |
-
init_fn_(module.k_proj_weight)
|
| 120 |
-
init_fn_(module.v_proj_weight)
|
| 121 |
-
if module.in_proj_bias is not None:
|
| 122 |
-
torch.nn.init.zeros_(module.in_proj_bias)
|
| 123 |
-
if module.bias_k is not None:
|
| 124 |
-
torch.nn.init.zeros_(module.bias_k)
|
| 125 |
-
if module.bias_v is not None:
|
| 126 |
-
torch.nn.init.zeros_(module.bias_v)
|
| 127 |
-
init_fn_(module.out_proj.weight)
|
| 128 |
-
if init_div_is_residual is not False and getattr(module.out_proj, "_is_residual", False):
|
| 129 |
-
with torch.no_grad():
|
| 130 |
-
module.out_proj.weight.div_(div_is_residual)
|
| 131 |
-
if module.out_proj.bias is not None:
|
| 132 |
-
torch.nn.init.zeros_(module.out_proj.bias)
|
| 133 |
-
else:
|
| 134 |
-
for _ in module.parameters(recurse=False):
|
| 135 |
-
raise NotImplementedError(f"{module.__class__.__name__} parameters are not initialized by param_init_fn.")
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
def _normal_init_(std, mean=0.0):
|
| 139 |
-
return partial(torch.nn.init.normal_, mean=mean, std=std)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def _normal_param_init_fn_(
|
| 143 |
-
module: nn.Module,
|
| 144 |
-
std: float,
|
| 145 |
-
n_layers: int,
|
| 146 |
-
d_model: Optional[int] = None,
|
| 147 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 148 |
-
emb_init_std: Optional[float] = None,
|
| 149 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 150 |
-
verbose: int = 0,
|
| 151 |
-
**kwargs,
|
| 152 |
-
):
|
| 153 |
-
del kwargs
|
| 154 |
-
init_fn_ = _normal_init_(std=std)
|
| 155 |
-
if verbose > 1:
|
| 156 |
-
warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
|
| 157 |
-
generic_param_init_fn_(
|
| 158 |
-
module=module,
|
| 159 |
-
init_fn_=init_fn_,
|
| 160 |
-
d_model=d_model,
|
| 161 |
-
n_layers=n_layers,
|
| 162 |
-
init_div_is_residual=init_div_is_residual,
|
| 163 |
-
emb_init_std=emb_init_std,
|
| 164 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 165 |
-
verbose=verbose,
|
| 166 |
-
)
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
def baseline_param_init_fn_(
|
| 170 |
-
module: nn.Module,
|
| 171 |
-
init_std: float,
|
| 172 |
-
n_layers: int,
|
| 173 |
-
d_model: Optional[int] = None,
|
| 174 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 175 |
-
emb_init_std: Optional[float] = None,
|
| 176 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 177 |
-
verbose: int = 0,
|
| 178 |
-
**kwargs,
|
| 179 |
-
):
|
| 180 |
-
del kwargs
|
| 181 |
-
if init_std is None:
|
| 182 |
-
raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
|
| 183 |
-
_normal_param_init_fn_(
|
| 184 |
-
module=module,
|
| 185 |
-
std=init_std,
|
| 186 |
-
d_model=d_model,
|
| 187 |
-
n_layers=n_layers,
|
| 188 |
-
init_div_is_residual=init_div_is_residual,
|
| 189 |
-
emb_init_std=emb_init_std,
|
| 190 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 191 |
-
verbose=verbose,
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def small_param_init_fn_(
|
| 196 |
-
module: nn.Module,
|
| 197 |
-
n_layers: int,
|
| 198 |
-
d_model: int,
|
| 199 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 200 |
-
emb_init_std: Optional[float] = None,
|
| 201 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 202 |
-
verbose: int = 0,
|
| 203 |
-
**kwargs,
|
| 204 |
-
):
|
| 205 |
-
del kwargs
|
| 206 |
-
std = math.sqrt(2 / (5 * d_model))
|
| 207 |
-
_normal_param_init_fn_(
|
| 208 |
-
module=module,
|
| 209 |
-
std=std,
|
| 210 |
-
d_model=d_model,
|
| 211 |
-
n_layers=n_layers,
|
| 212 |
-
init_div_is_residual=init_div_is_residual,
|
| 213 |
-
emb_init_std=emb_init_std,
|
| 214 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 215 |
-
verbose=verbose,
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def neox_param_init_fn_(
|
| 220 |
-
module: nn.Module,
|
| 221 |
-
n_layers: int,
|
| 222 |
-
d_model: int,
|
| 223 |
-
emb_init_std: Optional[float] = None,
|
| 224 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 225 |
-
verbose: int = 0,
|
| 226 |
-
**kwargs,
|
| 227 |
-
):
|
| 228 |
-
"""From section 2.3.1 of GPT-NeoX-20B:
|
| 229 |
-
|
| 230 |
-
An Open-Source AutoregressiveLanguage Model — Black et. al. (2022)
|
| 231 |
-
see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
|
| 232 |
-
and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
|
| 233 |
-
"""
|
| 234 |
-
del kwargs
|
| 235 |
-
residual_div = n_layers / math.sqrt(10)
|
| 236 |
-
if verbose > 1:
|
| 237 |
-
warnings.warn(f"setting init_div_is_residual to {residual_div}")
|
| 238 |
-
small_param_init_fn_(
|
| 239 |
-
module=module,
|
| 240 |
-
d_model=d_model,
|
| 241 |
-
n_layers=n_layers,
|
| 242 |
-
init_div_is_residual=residual_div,
|
| 243 |
-
emb_init_std=emb_init_std,
|
| 244 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 245 |
-
verbose=verbose,
|
| 246 |
-
)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def kaiming_uniform_param_init_fn_(
|
| 250 |
-
module: nn.Module,
|
| 251 |
-
n_layers: int,
|
| 252 |
-
d_model: Optional[int] = None,
|
| 253 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 254 |
-
emb_init_std: Optional[float] = None,
|
| 255 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 256 |
-
init_gain: float = 0,
|
| 257 |
-
fan_mode: str = "fan_in",
|
| 258 |
-
init_nonlinearity: str = "leaky_relu",
|
| 259 |
-
verbose: int = 0,
|
| 260 |
-
**kwargs,
|
| 261 |
-
):
|
| 262 |
-
del kwargs
|
| 263 |
-
if verbose > 1:
|
| 264 |
-
warnings.warn(f"Using nn.init.kaiming_uniform_ init fn with parameters: " + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
|
| 265 |
-
kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
| 266 |
-
generic_param_init_fn_(
|
| 267 |
-
module=module,
|
| 268 |
-
init_fn_=kaiming_uniform_,
|
| 269 |
-
d_model=d_model,
|
| 270 |
-
n_layers=n_layers,
|
| 271 |
-
init_div_is_residual=init_div_is_residual,
|
| 272 |
-
emb_init_std=emb_init_std,
|
| 273 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 274 |
-
verbose=verbose,
|
| 275 |
-
)
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
def kaiming_normal_param_init_fn_(
|
| 279 |
-
module: nn.Module,
|
| 280 |
-
n_layers: int,
|
| 281 |
-
d_model: Optional[int] = None,
|
| 282 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 283 |
-
emb_init_std: Optional[float] = None,
|
| 284 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 285 |
-
init_gain: float = 0,
|
| 286 |
-
fan_mode: str = "fan_in",
|
| 287 |
-
init_nonlinearity: str = "leaky_relu",
|
| 288 |
-
verbose: int = 0,
|
| 289 |
-
**kwargs,
|
| 290 |
-
):
|
| 291 |
-
del kwargs
|
| 292 |
-
if verbose > 1:
|
| 293 |
-
warnings.warn(f"Using nn.init.kaiming_normal_ init fn with parameters: " + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}")
|
| 294 |
-
kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
|
| 295 |
-
generic_param_init_fn_(
|
| 296 |
-
module=module,
|
| 297 |
-
init_fn_=kaiming_normal_,
|
| 298 |
-
d_model=d_model,
|
| 299 |
-
n_layers=n_layers,
|
| 300 |
-
init_div_is_residual=init_div_is_residual,
|
| 301 |
-
emb_init_std=emb_init_std,
|
| 302 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 303 |
-
verbose=verbose,
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
def xavier_uniform_param_init_fn_(
|
| 308 |
-
module: nn.Module,
|
| 309 |
-
n_layers: int,
|
| 310 |
-
d_model: Optional[int] = None,
|
| 311 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 312 |
-
emb_init_std: Optional[float] = None,
|
| 313 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 314 |
-
init_gain: float = 0,
|
| 315 |
-
verbose: int = 0,
|
| 316 |
-
**kwargs,
|
| 317 |
-
):
|
| 318 |
-
del kwargs
|
| 319 |
-
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
|
| 320 |
-
if verbose > 1:
|
| 321 |
-
warnings.warn(f"Using torch.nn.init.xavier_uniform_ init fn with parameters: " + f"gain={init_gain}")
|
| 322 |
-
generic_param_init_fn_(
|
| 323 |
-
module=module,
|
| 324 |
-
init_fn_=xavier_uniform_,
|
| 325 |
-
d_model=d_model,
|
| 326 |
-
n_layers=n_layers,
|
| 327 |
-
init_div_is_residual=init_div_is_residual,
|
| 328 |
-
emb_init_std=emb_init_std,
|
| 329 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 330 |
-
verbose=verbose,
|
| 331 |
-
)
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
def xavier_normal_param_init_fn_(
|
| 335 |
-
module: nn.Module,
|
| 336 |
-
n_layers: int,
|
| 337 |
-
d_model: Optional[int] = None,
|
| 338 |
-
init_div_is_residual: Union[int, float, str, bool] = True,
|
| 339 |
-
emb_init_std: Optional[float] = None,
|
| 340 |
-
emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
|
| 341 |
-
init_gain: float = 0,
|
| 342 |
-
verbose: int = 0,
|
| 343 |
-
**kwargs,
|
| 344 |
-
):
|
| 345 |
-
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
|
| 346 |
-
if verbose > 1:
|
| 347 |
-
warnings.warn(f"Using torch.nn.init.xavier_normal_ init fn with parameters: " + f"gain={init_gain}")
|
| 348 |
-
generic_param_init_fn_(
|
| 349 |
-
module=module,
|
| 350 |
-
init_fn_=xavier_normal_,
|
| 351 |
-
d_model=d_model,
|
| 352 |
-
n_layers=n_layers,
|
| 353 |
-
init_div_is_residual=init_div_is_residual,
|
| 354 |
-
emb_init_std=emb_init_std,
|
| 355 |
-
emb_init_uniform_lim=emb_init_uniform_lim,
|
| 356 |
-
verbose=verbose,
|
| 357 |
-
)
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
MODEL_INIT_REGISTRY = {
|
| 361 |
-
"default_": torch_default_param_init_fn_,
|
| 362 |
-
"baseline_": baseline_param_init_fn_,
|
| 363 |
-
"kaiming_uniform_": kaiming_uniform_param_init_fn_,
|
| 364 |
-
"kaiming_normal_": kaiming_normal_param_init_fn_,
|
| 365 |
-
"neox_init_": neox_param_init_fn_,
|
| 366 |
-
"small_init_": small_param_init_fn_,
|
| 367 |
-
"xavier_uniform_": xavier_uniform_param_init_fn_,
|
| 368 |
-
"xavier_normal_": xavier_normal_param_init_fn_,
|
| 369 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mllm/flamingo/mpt_redpajama/__init__.py
DELETED
|
File without changes
|
mllm/flamingo/mpt_redpajama/__pycache__/__init__.cpython-39.pyc
DELETED
|
Binary file (216 Bytes)
|
|
|