Update modeling_baichuan.py
Browse files- modeling_baichuan.py +16 -14
modeling_baichuan.py
CHANGED
|
@@ -59,7 +59,7 @@ def _make_causal_mask(
|
|
| 59 |
Make causal mask used for bi-directional self-attention.
|
| 60 |
"""
|
| 61 |
bsz, tgt_len = input_ids_shape
|
| 62 |
-
mask = torch.full((tgt_len, tgt_len),
|
| 63 |
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 64 |
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 65 |
mask = mask.to(dtype)
|
|
@@ -109,15 +109,14 @@ class RMSNorm(nn.Module):
|
|
| 109 |
class RotaryEmbedding(torch.nn.Module):
|
| 110 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 111 |
super().__init__()
|
| 112 |
-
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
| 113 |
self.max_seq_len_cached = max_position_embeddings
|
| 114 |
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 115 |
freqs = torch.outer(t, self.inv_freq)
|
| 116 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 117 |
-
self.
|
| 118 |
-
self.
|
| 119 |
-
|
| 120 |
-
def forward(self, x, seq_len):
|
| 121 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 122 |
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
| 123 |
if seq_len > self.max_seq_len_cached:
|
|
@@ -125,11 +124,14 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 125 |
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 126 |
freqs = torch.outer(t, self.inv_freq)
|
| 127 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 128 |
-
self.
|
| 129 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 130 |
return (
|
| 131 |
-
self.cos_cached[:, :, :seq_len,
|
| 132 |
-
self.sin_cached[:, :, :seq_len,
|
| 133 |
)
|
| 134 |
|
| 135 |
|
|
@@ -208,7 +210,7 @@ class Attention(nn.Module):
|
|
| 208 |
|
| 209 |
kv_seq_len = key_states.shape[-2]
|
| 210 |
if past_key_value is not None:
|
| 211 |
-
kv_seq_len =
|
| 212 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 213 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 214 |
# [bsz, nh, t, hd]
|
|
@@ -228,8 +230,8 @@ class Attention(nn.Module):
|
|
| 228 |
query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
|
| 229 |
)
|
| 230 |
else:
|
| 231 |
-
|
| 232 |
-
|
| 233 |
attn_output = attn_output.transpose(1, 2)
|
| 234 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 235 |
attn_output = self.o_proj(attn_output)
|
|
@@ -701,4 +703,4 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
| 701 |
else:
|
| 702 |
outputs = self.generate(input_ids, generation_config=generation_config)
|
| 703 |
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
| 704 |
-
return response
|
|
|
|
| 59 |
Make causal mask used for bi-directional self-attention.
|
| 60 |
"""
|
| 61 |
bsz, tgt_len = input_ids_shape
|
| 62 |
+
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
|
| 63 |
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 64 |
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 65 |
mask = mask.to(dtype)
|
|
|
|
| 109 |
class RotaryEmbedding(torch.nn.Module):
|
| 110 |
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
| 111 |
super().__init__()
|
| 112 |
+
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
| 113 |
self.max_seq_len_cached = max_position_embeddings
|
| 114 |
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 115 |
freqs = torch.outer(t, self.inv_freq)
|
| 116 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 117 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32)
|
| 118 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32)
|
| 119 |
+
def forward(self, x, seq_len=None):
|
|
|
|
| 120 |
# x: [bs, num_attention_heads, seq_len, head_size]
|
| 121 |
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
| 122 |
if seq_len > self.max_seq_len_cached:
|
|
|
|
| 124 |
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
|
| 125 |
freqs = torch.outer(t, self.inv_freq)
|
| 126 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 127 |
+
self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device)
|
| 128 |
+
self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device)
|
| 129 |
+
elif self.cos_cached.device != x.device:
|
| 130 |
+
self.cos_cached = self.cos_cached.to(x.device)
|
| 131 |
+
self.sin_cached = self.sin_cached.to(x.device)
|
| 132 |
return (
|
| 133 |
+
self.cos_cached[:, :, :seq_len, ...],
|
| 134 |
+
self.sin_cached[:, :, :seq_len, ...],
|
| 135 |
)
|
| 136 |
|
| 137 |
|
|
|
|
| 210 |
|
| 211 |
kv_seq_len = key_states.shape[-2]
|
| 212 |
if past_key_value is not None:
|
| 213 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 214 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 215 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 216 |
# [bsz, nh, t, hd]
|
|
|
|
| 230 |
query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
|
| 231 |
)
|
| 232 |
else:
|
| 233 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
| 234 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask = attention_mask)
|
| 235 |
attn_output = attn_output.transpose(1, 2)
|
| 236 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 237 |
attn_output = self.o_proj(attn_output)
|
|
|
|
| 703 |
else:
|
| 704 |
outputs = self.generate(input_ids, generation_config=generation_config)
|
| 705 |
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
| 706 |
+
return response
|