Update modelling_RW.py
Browse files
- modelling_RW.py +41 -35
modelling_RW.py
CHANGED
|
@@ -175,32 +175,42 @@ class Attention(nn.Module):
|
|
| 175 |
|
| 176 |
self.query_key_value = Linear(
|
| 177 |
self.hidden_size,
|
| 178 |
-
|
| 179 |
bias=config.bias,
|
| 180 |
)
|
| 181 |
-
self.multi_query = config.multi_query
|
| 182 |
self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
|
| 183 |
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
| 184 |
-
self.num_kv = config.
|
| 185 |
|
| 186 |
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 187 |
"""
|
| 188 |
-
Split the last dimension into (num_heads, head_dim)
|
| 189 |
storage as `fused_qkv`
|
| 190 |
Args:
|
| 191 |
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
| 192 |
Returns:
|
| 193 |
-
query: [batch_size, seq_length, num_heads, head_dim]
|
|
|
|
| 194 |
value: [batch_size, seq_length, num_heads, head_dim]
|
| 195 |
"""
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
|
| 205 |
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
| 206 |
"""
|
|
@@ -244,11 +254,11 @@ class Attention(nn.Module):
|
|
| 244 |
|
| 245 |
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
| 246 |
key_layer = key_layer.transpose(1, 2).reshape(
|
| 247 |
-
batch_size * self.
|
| 248 |
q_length,
|
| 249 |
self.head_dim,
|
| 250 |
)
|
| 251 |
-
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.
|
| 252 |
|
| 253 |
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
|
| 254 |
|
|
@@ -269,8 +279,8 @@ class Attention(nn.Module):
|
|
| 269 |
|
| 270 |
if alibi is None:
|
| 271 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 272 |
-
key_layer_ = key_layer.reshape(batch_size, self.
|
| 273 |
-
value_layer_ = value_layer.reshape(batch_size, self.
|
| 274 |
|
| 275 |
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, torch.finfo(torch.float16).min).to(query_layer_.dtype)
|
| 276 |
attn_output = F.scaled_dot_product_attention(
|
|
@@ -300,7 +310,8 @@ class Attention(nn.Module):
|
|
| 300 |
attention_scores = attention_scores.to(torch.float32)
|
| 301 |
# attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
| 302 |
attention_probs = F.softmax(
|
| 303 |
-
(attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
|
|
|
|
| 304 |
dim=-1,
|
| 305 |
dtype=hidden_states.dtype,
|
| 306 |
)
|
|
@@ -349,14 +360,12 @@ class DecoderLayer(nn.Module):
|
|
| 349 |
super().__init__()
|
| 350 |
hidden_size = config.hidden_size
|
| 351 |
|
| 352 |
-
self.
|
|
|
|
|
|
|
| 353 |
self.num_heads = config.n_head
|
| 354 |
self.self_attention = Attention(config)
|
| 355 |
|
| 356 |
-
if not config.parallel_attn:
|
| 357 |
-
# unused if parallel attn
|
| 358 |
-
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 359 |
-
|
| 360 |
self.mlp = MLP(config)
|
| 361 |
|
| 362 |
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
|
@@ -375,12 +384,14 @@ class DecoderLayer(nn.Module):
|
|
| 375 |
output_attentions: bool = False,
|
| 376 |
):
|
| 377 |
|
| 378 |
-
|
|
|
|
|
|
|
| 379 |
residual = hidden_states
|
| 380 |
|
| 381 |
# Self attention.
|
| 382 |
attn_outputs = self.self_attention(
|
| 383 |
-
|
| 384 |
layer_past=layer_past,
|
| 385 |
attention_mask=attention_mask,
|
| 386 |
alibi=alibi,
|
|
@@ -391,19 +402,14 @@ class DecoderLayer(nn.Module):
|
|
| 391 |
|
| 392 |
attention_output = attn_outputs[0]
|
| 393 |
|
| 394 |
-
if not self.config.parallel_attn:
|
| 395 |
-
residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
|
| 396 |
-
layernorm_output = self.post_attention_layernorm(residual)
|
| 397 |
-
|
| 398 |
outputs = attn_outputs[1:]
|
| 399 |
|
| 400 |
# MLP.
|
| 401 |
-
mlp_output = self.mlp(
|
| 402 |
-
|
| 403 |
-
if self.config.parallel_attn:
|
| 404 |
-
mlp_output += attention_output
|
| 405 |
|
| 406 |
-
output = dropout_add(
|
|
|
|
|
|
|
| 407 |
|
| 408 |
if use_cache:
|
| 409 |
outputs = (output,) + outputs
|
|
@@ -1093,4 +1099,4 @@ class RWForQuestionAnswering(RWPreTrainedModel):
|
|
| 1093 |
end_logits=end_logits,
|
| 1094 |
hidden_states=outputs.hidden_states,
|
| 1095 |
attentions=outputs.attentions,
|
| 1096 |
-
)
|
|
|
|
| 175 |
|
| 176 |
self.query_key_value = Linear(
|
| 177 |
self.hidden_size,
|
| 178 |
+
(config.n_head_kv * 2 + config.n_head) * self.head_dim,
|
| 179 |
bias=config.bias,
|
| 180 |
)
|
|
|
|
| 181 |
self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
|
| 182 |
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
| 183 |
+
self.num_kv = config.n_head_kv
|
| 184 |
|
| 185 |
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 186 |
"""
|
| 187 |
+
Split the last dimension into (num_heads, head_dim), results share same memory
|
| 188 |
storage as `fused_qkv`
|
| 189 |
Args:
|
| 190 |
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
| 191 |
Returns:
|
| 192 |
+
query: [batch_size, seq_length, num_heads, head_dim]
|
| 193 |
+
key: [batch_size, seq_length, num_heads, head_dim]
|
| 194 |
value: [batch_size, seq_length, num_heads, head_dim]
|
| 195 |
"""
|
| 196 |
+
batch, seq_len, _ = fused_qkv.shape
|
| 197 |
+
qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
|
| 198 |
+
q = qkv[:, :, :, :-2]
|
| 199 |
+
k = qkv[:, :, :, [-2]]
|
| 200 |
+
v = qkv[:, :, :, [-1]]
|
| 201 |
+
k = torch.broadcast_to(k, q.shape)
|
| 202 |
+
v = torch.broadcast_to(v, q.shape)
|
| 203 |
+
|
| 204 |
+
q, k, v = [
|
| 205 |
+
rearrange(
|
| 206 |
+
x,
|
| 207 |
+
"batch seq_len group num_heads head_dim ->\
|
| 208 |
+
batch seq_len (group num_heads) head_dim",
|
| 209 |
+
head_dim=self.head_dim,
|
| 210 |
+
)
|
| 211 |
+
for x in [q, k, v]
|
| 212 |
+
]
|
| 213 |
+
return q, k, v
|
| 214 |
|
| 215 |
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
| 216 |
"""
|
|
|
|
| 254 |
|
| 255 |
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
| 256 |
key_layer = key_layer.transpose(1, 2).reshape(
|
| 257 |
+
batch_size * self.num_heads,
|
| 258 |
q_length,
|
| 259 |
self.head_dim,
|
| 260 |
)
|
| 261 |
+
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
| 262 |
|
| 263 |
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
|
| 264 |
|
|
|
|
| 279 |
|
| 280 |
if alibi is None:
|
| 281 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 282 |
+
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 283 |
+
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 284 |
|
| 285 |
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, torch.finfo(torch.float16).min).to(query_layer_.dtype)
|
| 286 |
attn_output = F.scaled_dot_product_attention(
|
|
|
|
| 310 |
attention_scores = attention_scores.to(torch.float32)
|
| 311 |
# attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
| 312 |
attention_probs = F.softmax(
|
| 313 |
+
(attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
|
| 314 |
+
+ attention_mask_float,
|
| 315 |
dim=-1,
|
| 316 |
dtype=hidden_states.dtype,
|
| 317 |
)
|
|
|
|
| 360 |
super().__init__()
|
| 361 |
hidden_size = config.hidden_size
|
| 362 |
|
| 363 |
+
self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 364 |
+
self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 365 |
+
|
| 366 |
self.num_heads = config.n_head
|
| 367 |
self.self_attention = Attention(config)
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
self.mlp = MLP(config)
|
| 370 |
|
| 371 |
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
|
|
|
| 384 |
output_attentions: bool = False,
|
| 385 |
):
|
| 386 |
|
| 387 |
+
ln_attn = self.ln_attn(hidden_states)
|
| 388 |
+
ln_mlp = self.ln_mlp(hidden_states)
|
| 389 |
+
|
| 390 |
residual = hidden_states
|
| 391 |
|
| 392 |
# Self attention.
|
| 393 |
attn_outputs = self.self_attention(
|
| 394 |
+
ln_attn,
|
| 395 |
layer_past=layer_past,
|
| 396 |
attention_mask=attention_mask,
|
| 397 |
alibi=alibi,
|
|
|
|
| 402 |
|
| 403 |
attention_output = attn_outputs[0]
|
| 404 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
outputs = attn_outputs[1:]
|
| 406 |
|
| 407 |
# MLP.
|
| 408 |
+
mlp_output = self.mlp(ln_mlp)
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
+
output = dropout_add(
|
| 411 |
+
mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
|
| 412 |
+
)
|
| 413 |
|
| 414 |
if use_cache:
|
| 415 |
outputs = (output,) + outputs
|
|
|
|
| 1099 |
end_logits=end_logits,
|
| 1100 |
hidden_states=outputs.hidden_states,
|
| 1101 |
attentions=outputs.attentions,
|
| 1102 |
+
)
|