Upload 4 files (#1) — opened by phoebeklett
- attention.py +5 -6
- blocks.py +1 -1
- configuration.py +5 -0
- modeling_mpt.py +15 -19
attention.py
CHANGED
|
@@ -95,10 +95,10 @@ def scaled_multihead_dot_product_attention(
|
|
| 95 |
)
|
| 96 |
attn_weight = attn_weight + attn_bias
|
| 97 |
|
| 98 |
-
if needs_weights:
|
| 99 |
reshaped_idx = None
|
| 100 |
if long_range_past_key_value is not None or faiss_indexes is not None:
|
| 101 |
-
if long_range_past_key_value is not None: #manual
|
| 102 |
|
| 103 |
k_cache, v_cache = long_range_past_key_value
|
| 104 |
s_cache = k_cache.size(-1)
|
|
@@ -134,15 +134,14 @@ def scaled_multihead_dot_product_attention(
|
|
| 134 |
|
| 135 |
selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:d], '(h s) d -> 1 h d s', h=32).to(q.device)
|
| 136 |
selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,d:], '(h s) d -> 1 h s d', h=32).to(q.device)
|
| 137 |
-
|
| 138 |
s_k_ae = selected_k.size(-1)
|
| 139 |
s_k += s_k_ae
|
| 140 |
attn_weight_cache = q.matmul(selected_k) * softmax_scale
|
| 141 |
if mask_by_sim:
|
| 142 |
attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
|
| 143 |
|
| 144 |
-
if attn_bias_ae is not None:
|
| 145 |
-
# clamp to 0 necessary for torch 2.0 compile()
|
| 146 |
_s_q = max(0, attn_bias_ae.size(2) - s_q)
|
| 147 |
_s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
|
| 148 |
attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
|
|
@@ -710,7 +709,7 @@ def build_attn_bias(
|
|
| 710 |
for_ae=for_ae,
|
| 711 |
topk=topk
|
| 712 |
))
|
| 713 |
-
else:
|
| 714 |
attn_bias = build_alibi_bias(
|
| 715 |
n_heads,
|
| 716 |
seq_len,
|
|
|
|
| 95 |
)
|
| 96 |
attn_weight = attn_weight + attn_bias
|
| 97 |
|
| 98 |
+
if needs_weights: #will return memory indices w/attention weights
|
| 99 |
reshaped_idx = None
|
| 100 |
if long_range_past_key_value is not None or faiss_indexes is not None:
|
| 101 |
+
if long_range_past_key_value is not None: #manual memories
|
| 102 |
|
| 103 |
k_cache, v_cache = long_range_past_key_value
|
| 104 |
s_cache = k_cache.size(-1)
|
|
|
|
| 134 |
|
| 135 |
selected_k=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,:d], '(h s) d -> 1 h d s', h=32).to(q.device)
|
| 136 |
selected_v=rearrange(torch.tensor(kv_index.reconstruct_batch(I.flatten()))[:,d:], '(h s) d -> 1 h s d', h=32).to(q.device)
|
| 137 |
+
|
| 138 |
s_k_ae = selected_k.size(-1)
|
| 139 |
s_k += s_k_ae
|
| 140 |
attn_weight_cache = q.matmul(selected_k) * softmax_scale
|
| 141 |
if mask_by_sim:
|
| 142 |
attn_weight_cache = attn_weight_cache.masked_fill(sim_mask, min_val)
|
| 143 |
|
| 144 |
+
if attn_bias_ae is not None: #add alibi bias to memories
|
|
|
|
| 145 |
_s_q = max(0, attn_bias_ae.size(2) - s_q)
|
| 146 |
_s_k = max(0, attn_bias_ae.size(3) - s_k_ae)
|
| 147 |
attn_bias_ae = attn_bias_ae[:, :, _s_q:, _s_k:]
|
|
|
|
| 709 |
for_ae=for_ae,
|
| 710 |
topk=topk
|
| 711 |
))
|
| 712 |
+
else: #for memories
|
| 713 |
attn_bias = build_alibi_bias(
|
| 714 |
n_heads,
|
| 715 |
seq_len,
|
blocks.py
CHANGED
|
@@ -7,7 +7,7 @@
|
|
| 7 |
from typing import Dict, Optional, Tuple
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
-
from .attention import ATTN_CLASS_REGISTRY
|
| 11 |
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
| 12 |
|
| 13 |
class MPTMLP(nn.Module):
|
|
|
|
| 7 |
from typing import Dict, Optional, Tuple
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
+
from extended_mind_transformers.mpt.attention import ATTN_CLASS_REGISTRY
|
| 11 |
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
| 12 |
|
| 13 |
class MPTMLP(nn.Module):
|
configuration.py
CHANGED
|
@@ -165,6 +165,11 @@ class ExtendedMPTConfig(PretrainedConfig):
|
|
| 165 |
init_config_defaults,
|
| 166 |
)
|
| 167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
if self.d_model % self.n_heads != 0:
|
| 169 |
raise ValueError('d_model must be divisible by n_heads')
|
| 170 |
if any(
|
|
|
|
| 165 |
init_config_defaults,
|
| 166 |
)
|
| 167 |
|
| 168 |
+
if self.attn_config['memory_type']=='faiss' and self.attn_config['mask_by_sim'] is True:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
'mask_by_sim is not supported for faiss memory type.'
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
if self.d_model % self.n_heads != 0:
|
| 174 |
raise ValueError('d_model must be divisible by n_heads')
|
| 175 |
if any(
|
modeling_mpt.py
CHANGED
|
@@ -27,10 +27,10 @@ from llmfoundry.models.layers.custom_embedding import SharedEmbedding
|
|
| 27 |
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
| 28 |
from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
|
| 29 |
|
| 30 |
-
from .configuration import ExtendedMPTConfig
|
| 31 |
-
from .attention import attn_bias_shape, build_attn_bias
|
| 32 |
-
from .blocks import MPTBlock
|
| 33 |
-
from .utils import instantiate_from_config
|
| 34 |
|
| 35 |
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 36 |
|
|
@@ -111,7 +111,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
|
|
| 111 |
causal=self.is_causal,
|
| 112 |
use_sequence_id=self.attn_uses_sequence_id,
|
| 113 |
)
|
| 114 |
-
self._attn_bias_ae_initialized = False
|
| 115 |
self.attn_bias_ae = None
|
| 116 |
|
| 117 |
if self.config.no_bias:
|
|
@@ -168,7 +168,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
|
|
| 168 |
)
|
| 169 |
self._attn_bias_initialized = True
|
| 170 |
|
| 171 |
-
if use_active_externalism:
|
| 172 |
self.attn_bias_ae = build_attn_bias(
|
| 173 |
self.attn_impl,
|
| 174 |
self.config.n_heads,
|
|
@@ -196,7 +196,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
|
|
| 196 |
|
| 197 |
attn_bias = self.attn_bias
|
| 198 |
|
| 199 |
-
if self.attn_bias_ae is not None:
|
| 200 |
self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
|
| 201 |
attn_bias_ae = self.attn_bias_ae
|
| 202 |
|
|
@@ -417,9 +417,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
|
|
| 417 |
assert isinstance(self.emb_drop, nn.Module) # pyright
|
| 418 |
x = self.emb_drop(x_shrunk)
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
seq_len = S
|
| 423 |
if past_key_values is not None:
|
| 424 |
past_position = past_key_values[0][0].size(-1)
|
| 425 |
seq_len += past_position
|
|
@@ -493,7 +491,7 @@ class ExtendedMPTModel(MPTPreTrainedModel):
|
|
| 493 |
last_hidden_state=x,
|
| 494 |
past_key_values=past_key_values,
|
| 495 |
hidden_states=all_hidden_states,
|
| 496 |
-
attentions=(all_self_attns, all_idx),
|
| 497 |
)
|
| 498 |
|
| 499 |
# Param Initialization, needed for device='meta' fast initialization
|
|
@@ -598,7 +596,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 598 |
use_active_externalism: Optional[bool]=None,
|
| 599 |
topk:int=None
|
| 600 |
):
|
| 601 |
-
if self._memories is not None and self.memories is None:
|
| 602 |
self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
|
| 603 |
|
| 604 |
return_dict = (return_dict
|
|
@@ -702,9 +700,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 702 |
prev_end_loc=0
|
| 703 |
long_range_past_key_values = None
|
| 704 |
faiss_indexes= None
|
| 705 |
-
for b_idx in range(0, input_ids.size(-1), stride):
|
| 706 |
end_loc = min(b_idx + max_len, input_ids.size(-1))
|
| 707 |
-
|
| 708 |
trg_len = end_loc - prev_end_loc
|
| 709 |
subseq = input_ids[:, b_idx:end_loc].to(self.device)
|
| 710 |
with torch.no_grad():
|
|
@@ -734,7 +731,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 734 |
if long_range_past_key_values is not None and faiss_indexes is not None:
|
| 735 |
raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
|
| 736 |
|
| 737 |
-
if cache_type=='faiss':
|
| 738 |
one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
|
| 739 |
if faiss_indexes is None:
|
| 740 |
faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
|
|
@@ -747,7 +744,6 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 747 |
k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
|
| 748 |
v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
|
| 749 |
kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
|
| 750 |
-
|
| 751 |
else:
|
| 752 |
if long_range_past_key_values is None:
|
| 753 |
long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
|
|
@@ -759,8 +755,8 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 759 |
)
|
| 760 |
for ind, kv in enumerate(long_range_past_key_values)
|
| 761 |
]
|
| 762 |
-
if long_range_past_key_values is not None:
|
| 763 |
-
if long_range_past_key_values[0][0].size(-1) > max_length_cache:
|
| 764 |
long_range_past_key_values = [
|
| 765 |
(
|
| 766 |
kv[0][:, :, :, -max_length_cache:],
|
|
@@ -816,7 +812,7 @@ class ExtendedMPTForCausalLM(MPTPreTrainedModel):
|
|
| 816 |
'sequence_id': sequence_id,
|
| 817 |
'past_key_values': past_key_values,
|
| 818 |
'use_cache': kwargs.get('use_cache', True),
|
| 819 |
-
'use_active_externalism': kwargs.get('use_active_externalism'),
|
| 820 |
'topk': kwargs.get('topk', None),
|
| 821 |
}
|
| 822 |
|
|
|
|
| 27 |
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
|
| 28 |
from llmfoundry.models.utils.param_init_fns import MODEL_INIT_REGISTRY
|
| 29 |
|
| 30 |
+
from extended_mind_transformers.mpt.configuration import ExtendedMPTConfig
|
| 31 |
+
from extended_mind_transformers.mpt.attention import attn_bias_shape, build_attn_bias
|
| 32 |
+
from extended_mind_transformers.mpt.blocks import MPTBlock
|
| 33 |
+
from extended_mind_transformers.utils import instantiate_from_config
|
| 34 |
|
| 35 |
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
| 36 |
|
|
|
|
| 111 |
causal=self.is_causal,
|
| 112 |
use_sequence_id=self.attn_uses_sequence_id,
|
| 113 |
)
|
| 114 |
+
self._attn_bias_ae_initialized = False #for active externalism
|
| 115 |
self.attn_bias_ae = None
|
| 116 |
|
| 117 |
if self.config.no_bias:
|
|
|
|
| 168 |
)
|
| 169 |
self._attn_bias_initialized = True
|
| 170 |
|
| 171 |
+
if use_active_externalism: #for active externalism, init every time since seq_len changes
|
| 172 |
self.attn_bias_ae = build_attn_bias(
|
| 173 |
self.attn_impl,
|
| 174 |
self.config.n_heads,
|
|
|
|
| 196 |
|
| 197 |
attn_bias = self.attn_bias
|
| 198 |
|
| 199 |
+
if self.attn_bias_ae is not None: #for active externalism
|
| 200 |
self.attn_bias_ae = self.attn_bias_ae.to(dtype=dtype, device=device)
|
| 201 |
attn_bias_ae = self.attn_bias_ae
|
| 202 |
|
|
|
|
| 417 |
assert isinstance(self.emb_drop, nn.Module) # pyright
|
| 418 |
x = self.emb_drop(x_shrunk)
|
| 419 |
|
| 420 |
+
seq_len = S #for active externalism
|
|
|
|
|
|
|
| 421 |
if past_key_values is not None:
|
| 422 |
past_position = past_key_values[0][0].size(-1)
|
| 423 |
seq_len += past_position
|
|
|
|
| 491 |
last_hidden_state=x,
|
| 492 |
past_key_values=past_key_values,
|
| 493 |
hidden_states=all_hidden_states,
|
| 494 |
+
attentions=(all_self_attns, all_idx), #return reshaped_idx for active externalism
|
| 495 |
)
|
| 496 |
|
| 497 |
# Param Initialization, needed for device='meta' fast initialization
|
|
|
|
| 596 |
use_active_externalism: Optional[bool]=None,
|
| 597 |
topk:int=None
|
| 598 |
):
|
| 599 |
+
if self._memories is not None and self.memories is None: #init memories once on first call
|
| 600 |
self.memories = self.generate_cache(self._memories, cache_type=self.memory_type)
|
| 601 |
|
| 602 |
return_dict = (return_dict
|
|
|
|
| 700 |
prev_end_loc=0
|
| 701 |
long_range_past_key_values = None
|
| 702 |
faiss_indexes= None
|
| 703 |
+
for b_idx in range(0, input_ids.size(-1), stride): #generate kv-pairs using stride
|
| 704 |
end_loc = min(b_idx + max_len, input_ids.size(-1))
|
|
|
|
| 705 |
trg_len = end_loc - prev_end_loc
|
| 706 |
subseq = input_ids[:, b_idx:end_loc].to(self.device)
|
| 707 |
with torch.no_grad():
|
|
|
|
| 731 |
if long_range_past_key_values is not None and faiss_indexes is not None:
|
| 732 |
raise NotImplementedError("Using faiss and passing key value pairs manually are mutually exclusive right now.")
|
| 733 |
|
| 734 |
+
if cache_type=='faiss': #add one-hot encoding to match layer, head indices
|
| 735 |
one_hot_encodings = F.one_hot(torch.arange(0, self.config.n_heads*self.config.n_layers))*10
|
| 736 |
if faiss_indexes is None:
|
| 737 |
faiss_indexes = (faiss.IndexFlatIP(to_cache[0][0].size(-2)+one_hot_encodings.size(-1)), faiss.IndexFlatIP(to_cache[0][1].size(-1)*2))
|
|
|
|
| 744 |
k= rearrange(k, 'b h d s -> b (h s) d', h=self.config.n_heads)
|
| 745 |
v= rearrange(v, 'b h s d -> b (h s) d', h=self.config.n_heads)
|
| 746 |
kv_index.add(torch.concat([v.squeeze(), k.squeeze()], dim=1).to('cpu').numpy())
|
|
|
|
| 747 |
else:
|
| 748 |
if long_range_past_key_values is None:
|
| 749 |
long_range_past_key_values = [(k.to(self.memory_device),v.to(self.memory_device)) for k,v in to_cache]
|
|
|
|
| 755 |
)
|
| 756 |
for ind, kv in enumerate(long_range_past_key_values)
|
| 757 |
]
|
| 758 |
+
if long_range_past_key_values is not None: #set a limit on manual memory length
|
| 759 |
+
if long_range_past_key_values[0][0].size(-1) > max_length_cache:
|
| 760 |
long_range_past_key_values = [
|
| 761 |
(
|
| 762 |
kv[0][:, :, :, -max_length_cache:],
|
|
|
|
| 812 |
'sequence_id': sequence_id,
|
| 813 |
'past_key_values': past_key_values,
|
| 814 |
'use_cache': kwargs.get('use_cache', True),
|
| 815 |
+
'use_active_externalism': kwargs.get('use_active_externalism'), #add a few more kwargs for active externalism
|
| 816 |
'topk': kwargs.get('topk', None),
|
| 817 |
}
|
| 818 |
|