Update modeling_Llamoe.py
modeling_Llamoe.py  CHANGED  (+73 -31)
@@ -167,41 +167,58 @@ class LlamoeRMSNorm(nn.Module):
         return self.weight * hidden_states.to(input_dtype)
 
 
-
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
 class LlamoeRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
-
+        self.scaling_factor = scaling_factor
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
+        # For BC we register cos and sin cached
+        self.max_seq_len_cached = max_position_embeddings
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
+        t = t / self.scaling_factor
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
 
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+    @property
+    def sin_cached(self):
+        logger.warning_once(
+            "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
+        )
+        return self._sin_cached
 
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+    @property
+    def cos_cached(self):
+        logger.warning_once(
+            "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
         )
+        return self._cos_cached
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
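Note on the hunk above: the seq_len-keyed cos/sin cache is replaced by a forward pass over explicit position_ids, matching the transformers v4.38 RoPE refactor. A minimal usage sketch, assuming the class as defined in this diff; the shapes (batch=2, heads=4, seq_len=6, head_dim=8) are illustrative assumptions:

import torch

rope = LlamoeRotaryEmbedding(dim=8, max_position_embeddings=2048)
x = torch.randn(2, 4, 6, 8)                                # [bs, num_heads, seq_len, head_dim]
position_ids = torch.arange(6).unsqueeze(0).expand(2, -1)  # [bs, seq_len]; may be offset when a KV-cache is in use
cos, sin = rope(x, position_ids)                           # each [bs, seq_len, head_dim], computed in float32
assert cos.shape == (2, 6, 8) and sin.shape == (2, 6, 8)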
@@ -212,8 +229,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -221,9 +238,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -234,8 +250,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
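The three hunks above move position indexing out of apply_rotary_pos_emb: cos/sin now arrive from the RoPE forward already gathered per position, and position_ids survives only as a deprecated no-op argument. A sketch under the same illustrative shape assumptions as before:

import torch

q = torch.randn(2, 4, 6, 8)   # [bs, num_heads, seq_len, head_dim]
k = torch.randn(2, 4, 6, 8)   # [bs, num_kv_heads, seq_len, head_dim]
position_ids = torch.arange(6).unsqueeze(0).expand(2, -1)
cos, sin = LlamoeRotaryEmbedding(dim=8)(q, position_ids)   # [bs, seq_len, head_dim]
# unsqueeze_dim=1 inserts the head axis so cos/sin broadcast over every head
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
assert q_embed.shape == q.shape and k_embed.shape == k.shape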
@@ -254,7 +270,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-# Copied from transformers.models.mistral.modeling_mistral.
+# Copied from transformers.models.mistral.modeling_mistral.LlamaAttention with Llama->Mixtral
 class LlamoeAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -873,11 +889,11 @@ LLAMOE_START_DOCSTRING = r"""
 )
 # Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral
 class LlamoePreTrainedModel(PreTrainedModel):
-    config_class =
+    config_class = LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["LlamoeDecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
+    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
@@ -893,6 +909,32 @@ class LlamoePreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
 
+    def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
+            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
+
+        for layer in self.model.layers:
+            device = layer.input_layernorm.weight.device
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                dtype = self.config._pre_quantization_dtype
+            else:
+                dtype = layer.self_attn.o_proj.weight.dtype
+            layer.self_attn.past_key_value = cache_cls(
+                self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
+            )
+
+    def _reset_cache(self):
+        for layer in self.model.layers:
+            layer.self_attn.past_key_value = None
 
 LLAMOE_INPUTS_DOCSTRING = r"""
     Args:
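The added hunk wires the model into the static-cache machinery from transformers v4.38: _setup_cache pre-allocates one cache object per decoder layer (sized to max_cache_len) plus a boolean causal mask, and _reset_cache releases them. A hedged sketch of the intended call sequence, where `model` is a hypothetical loaded Llamoe checkpoint:

from transformers.cache_utils import StaticCache  # available in transformers >= 4.38

model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=512)
# ... run decoding; each layer's self_attn.past_key_value is now a pre-allocated StaticCache ...
model._reset_cache()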