Commit ff010c9
Parent(s): b97888f
Update modeling_bluelm.py

modeling_bluelm.py  +20 -22
CHANGED
@@ -32,7 +32,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_bluelm import BlueLMConfig
-
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+)
 
 try:
     from xformers import ops as xops
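Note: the new flash-attn import is unconditional, whereas the existing xformers import just below it is wrapped in try/except so the model still loads when xformers is absent. For comparison only, a minimal sketch (not part of this commit) of guarding the flash-attn import the same way:

    try:
        from flash_attn.flash_attn_interface import (
            flash_attn_func,
            flash_attn_kvpacked_func,
            flash_attn_qkvpacked_func,
            flash_attn_varlen_kvpacked_func,
        )
    except ImportError:
        # flash-attn not installed; the flash-attention path in BlueLMAttention
        # would then be unavailable and would need a fallback.
        flash_attn_func = flash_attn_kvpacked_func = None
        flash_attn_qkvpacked_func = flash_attn_varlen_kvpacked_func = None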
@@ -213,6 +218,11 @@ class BlueLMAttention(nn.Module):
             hidden_size,
             bias=False,
         )
+        self.register_buffer(
+            "norm_factor",
+            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
+            persistent=False,
+        )
         self.rotary_emb = BlueLMRotaryEmbedding(self.head_dim)
         if xops is not None:
             self.causal_mask = xops.LowerTriangularMask()
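Note: the norm_factor buffer registered above holds sqrt(head_dim), and the flash-attn call later in this diff passes softmax_scale=1.0/self.norm_factor, i.e. the standard 1/sqrt(head_dim) attention scaling. A self-contained sketch of the same computation (the head_dim value below is illustrative, not taken from the BlueLM config):

    import torch

    head_dim = 128  # illustrative; the real value comes from the model config
    norm_factor = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
    softmax_scale = 1.0 / norm_factor  # equals 1 / sqrt(head_dim)
    assert torch.isclose(softmax_scale, torch.tensor(1.0 / head_dim ** 0.5))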
@@ -230,7 +240,8 @@ class BlueLMAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
 
-        bsz, q_len, _ = hidden_states.size()
+        bsz, q_len, h_size = hidden_states.size()
+        has_layer_past = past_key_value is not None
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
@@ -245,7 +256,7 @@ class BlueLMAttention(nn.Module):
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
         # [bsz, t, nh, hd]
 
-        if past_key_value is not None:
+        if has_layer_past:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=1)
             value_states = torch.cat([past_key_value[1], value_states], dim=1)
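Note: when a cached layer past is present, the new key/value states are appended to the cache along dim=1, the time axis of the [bsz, t, nh, hd] layout noted in the comments. A small shape sketch under that assumption (all sizes illustrative):

    import torch

    bsz, nh, hd = 2, 32, 128
    past_k = torch.randn(bsz, 10, nh, hd)  # 10 cached timesteps, [bsz, t, nh, hd]
    new_k = torch.randn(bsz, 1, nh, hd)    # 1 new timestep during decoding
    k = torch.cat([past_k, new_k], dim=1)
    print(k.shape)                         # torch.Size([2, 11, 32, 128])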
@@ -260,25 +271,12 @@ class BlueLMAttention(nn.Module):
             )
         else:
             # [bsz, t, nh, hd]
-            attn_weights = torch.einsum("bqnh,bknh->bnqk", query_states, key_states) / math.sqrt(self.head_dim)
-
-            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                    f" {attn_weights.size()}"
-                )
-
-            if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    raise ValueError(
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
-                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            kv = torch.stack([key_states, value_states], 2)
+            attn_outputs = flash_attn_kvpacked_func(
+                query_states, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions)
+            attn_output = attn_outputs[0] if output_attentions else attn_outputs
+            attn_weights = attn_outputs[2] if output_attentions else None
 
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_output = torch.einsum("bnqk,bknh->bqnh", attn_weights, value_states)
 
         if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
             raise ValueError(
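Note: flash_attn_kvpacked_func expects q of shape [bsz, seqlen_q, nheads, head_dim] and a packed kv of shape [bsz, seqlen_k, 2, nheads, head_dim], which is why key and value states are stacked along dim 2 before the call; with return_attn_probs=True it is documented to return a tuple whose first element is the attention output and whose third element holds the attention probabilities, matching the [0]/[2] indexing above. A minimal usage sketch (requires a CUDA device, fp16/bf16 inputs, and the flash-attn package; all sizes illustrative):

    import torch
    from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

    bsz, q_len, nh, hd = 2, 16, 32, 128
    q = torch.randn(bsz, q_len, nh, hd, dtype=torch.float16, device="cuda")
    k = torch.randn(bsz, q_len, nh, hd, dtype=torch.float16, device="cuda")
    v = torch.randn(bsz, q_len, nh, hd, dtype=torch.float16, device="cuda")

    kv = torch.stack([k, v], 2)         # [bsz, seqlen_k, 2, nh, hd]
    out = flash_attn_kvpacked_func(
        q, kv,
        dropout_p=0.0,
        softmax_scale=1.0 / hd ** 0.5,  # same scaling as 1.0 / norm_factor above
        causal=True,                    # causal mask, as in the no-cache (prefill) case
    )
    print(out.shape)                    # torch.Size([2, 16, 32, 128])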
@@ -612,7 +610,7 @@ class BlueLMModel(BlueLMPreTrainedModel):
         seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]
+            past_key_values_length = past_key_values[0][0].shape[2]
         seq_length_with_past = seq_length_with_past + past_key_values_length
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
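Note: past_key_values_length is now read from dimension 2 of the first layer's cached key tensor and added to the current sequence length. A minimal sketch of that bookkeeping, assuming a cache layout whose dimension 2 is the past sequence length (all sizes illustrative):

    import torch

    seq_length = 1  # tokens in the current forward pass
    past_key_values = [(torch.randn(2, 32, 10, 128),   # illustrative cached key, past length at dim 2
                        torch.randn(2, 32, 10, 128))]  # illustrative cached value

    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
    seq_length_with_past = seq_length + past_key_values_length
    print(past_key_values_length, seq_length_with_past)  # 10 11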