merge

- config.json +1 -1
- model.safetensors +0 -3
- modeling_aria.py +25 -135
config.json CHANGED

@@ -10,7 +10,7 @@
   "model_type": "aria",
   "num_attention_heads": 24,
   "num_hidden_layers": 16,
-  "torch_dtype": "
+  "torch_dtype": "float32",
   "transformers_version": "4.45.0",
   "use_cache": true,
   "vocab_size": 17727,
model.safetensors DELETED

@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9057480d90c91e0b9000f365ceafcbd7e21cd1940dc4bb25f1bd328cbe26c28f
-size 2634170640
modeling_aria.py CHANGED
@@ -180,13 +180,11 @@ class TransformerBlock(nn.Module):
             xk, xv, self.layer_idx, cache_kwargs
         )

-        # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head)
         att = F.scaled_dot_product_attention(
             query=xq,
             key=xk,
             value=xv,
-            attn_mask=attention_mask,
-            # is_causal=True,
+            attn_mask=attention_mask[..., : xk.shape[2]],
         )

         # Reshape for out: (b_sz, s_len, n_head, d_head)
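The rewritten attention call passes a boolean causal mask sliced to the live key length instead of the old additive mask. A minimal, self-contained sketch of why the [..., : xk.shape[2]] slice is needed; all shapes below are illustrative and not taken from this model:

import torch
import torch.nn.functional as F

# Illustrative sketch, not the model's code: a boolean causal mask built for
# max_seq_len must be sliced down to the live key length before it is handed
# to scaled_dot_product_attention, since the mask's last dimension has to
# broadcast against the key sequence length. True marks attendable positions.
b_sz, n_head, d_head = 1, 2, 8
max_seq_len, cur_len = 16, 5

causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
xq = torch.randn(b_sz, n_head, cur_len, d_head)
xk = torch.randn(b_sz, n_head, cur_len, d_head)
xv = torch.randn(b_sz, n_head, cur_len, d_head)

attention_mask = causal[None, None, :cur_len]  # (1, 1, 5, 16)
att = F.scaled_dot_product_attention(
    query=xq,
    key=xk,
    value=xv,
    attn_mask=attention_mask[..., : xk.shape[2]],  # sliced to (1, 1, 5, 5)
)
print(att.shape)  # torch.Size([1, 2, 5, 8])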
@@ -215,6 +213,7 @@ class AriaModel(AriaPreTrainedModel):
         super().__init__(model_config)
         self.model_config = model_config
         self.freqs_cis = None
+        self.causal_mask = None

         self.tok_embeddings = nn.Embedding(
             num_embeddings=model_config.vocab_size,
@@ -341,13 +340,10 @@ class AriaModel(AriaPreTrainedModel):
         position_ids = cache_position.unsqueeze(0)
         hidden_states = inputs_embeds

-        causal_mask = self._update_causal_mask(
-            attention_mask,
-            inputs_embeds,
-            cache_position,
-            past_key_values,
-            output_attentions,
-        )
+        if self.causal_mask is None:
+            self.causal_mask = precompute_causal_mask(
+                max_seq_len=self.model_config.max_seq_len,
+            ).to(input_ids.device)

         if self.freqs_cis is None:
             self.freqs_cis = precompute_freqs_cis(
@@ -360,6 +356,19 @@ class AriaModel(AriaPreTrainedModel):

         freqs_cis = self.freqs_cis[cache_position]

+        if use_cache is True:
+            causal_mask = self.causal_mask[None, None, cache_position]
+        else:
+            causal_mask = self.causal_mask[None, None, :seq_length, :seq_length]
+
+        if attention_mask is not None:
+            pad_len = causal_mask.shape[3] - attention_mask.shape[1]
+            padded_attention_mask = F.pad(attention_mask, (0, pad_len), value=1)
+            padded_attention_mask = padded_attention_mask[:, None, None, :]
+            padded_attention_mask = padded_attention_mask.bool()
+
+            causal_mask = causal_mask & padded_attention_mask
+
         kwargs = {
             "position_ids": position_ids,
             "past_key_values": past_key_values,
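The added block indexes the same precomputed boolean table in two ways: with cache_position (one mask row per in-flight query) when the KV cache is in use, and with a square :seq_length slice otherwise, then ANDs in the padding mask. A toy walkthrough; the sizes and the left-padded batch are made up for illustration:

import torch
import torch.nn.functional as F

# Toy reproduction of the masking logic above; sizes are illustrative.
max_seq_len, seq_length = 6, 4
causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# use_cache path: gather one row per in-flight query position. Decoding the
# token at position 3 yields a (1, 1, 1, max_seq_len) mask.
cache_position = torch.tensor([3])
cached_mask = causal[None, None, cache_position]
print(cached_mask)  # [[[[ True,  True,  True,  True, False, False]]]]

# no-cache path: square slice, then AND with the padded 2D attention mask
# lifted to (b_sz, 1, 1, s_len).
causal_mask = causal[None, None, :seq_length, :seq_length]
attention_mask = torch.tensor([[0, 1, 1, 1]])  # one left-padded sequence
pad_len = causal_mask.shape[3] - attention_mask.shape[1]
padded = F.pad(attention_mask, (0, pad_len), value=1)[:, None, None, :].bool()
combined = causal_mask & padded
print(combined[0, 0].int())
# tensor([[0, 0, 0, 0],
#         [0, 1, 0, 0],
#         [0, 1, 1, 0],
#         [0, 1, 1, 1]], dtype=torch.int32)

Note that with left padding the first combined row comes out all False. That is the same fully-masked-row situation the deleted _unmask_unattended branch patched for SDPA's memory-efficient kernel; with a boolean mask such rows make the softmax produce NaNs, so outputs at padded positions presumably have to be ignored downstream.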
@@ -432,130 +441,6 @@ class AriaModel(AriaPreTrainedModel):
             attentions=all_attentions,
         )

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.model_config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length()
-            if past_key_values is not None
-            else 0
-        )
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and not using_static_cache
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = (
-            self._prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=target_length,
-                dtype=dtype,
-                device=device,
-                cache_position=cache_position,
-                batch_size=input_tensor.shape[0],
-            )
-        )
-
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
-            )
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length),
-                fill_value=min_dtype,
-                dtype=dtype,
-                device=device,
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(
-                target_length, device=device
-            ) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(
-                batch_size, 1, -1, -1
-            )
-            if attention_mask is not None:
-                causal_mask = (
-                    causal_mask.clone()
-                )  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = (
-                    causal_mask[:, :, :, :mask_length]
-                    + attention_mask[:, None, None, :]
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[
-                    :, :, :, :mask_length
-                ].masked_fill(padding_mask, min_dtype)
-
-        return causal_mask
-

 class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
     """Transformer decoder with head for language modelling.
@@ -732,6 +617,12 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
         )


+def precompute_causal_mask(max_seq_len: int):
+    return torch.tril(
+        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
+    ).cuda()
+
+
 def precompute_freqs_cis(
     seq_len: int,
     n_elem: int,
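The committed helper materializes the full (max_seq_len, max_seq_len) boolean triangle directly on CUDA, and the call site then re-moves it with .to(input_ids.device). For reference, a hypothetical device-parameterized variant of the same construction; precompute_causal_mask_on is illustrative and not part of this diff:

import torch

# Hypothetical variant of the helper above, parameterized over device; the
# committed version instead calls .cuda() and therefore assumes a GPU.
def precompute_causal_mask_on(max_seq_len: int, device: str = "cpu") -> torch.Tensor:
    # True on and below the diagonal: query i may attend to keys 0..i.
    return torch.tril(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool, device=device)
    )

mask = precompute_causal_mask_on(4)
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)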
@@ -749,7 +640,6 @@ def precompute_freqs_cis(
     return cache.to(dtype=dtype)


-@torch.jit.script
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     """
     In-place RoPE. Credits to Katherine Crowson: