add max_seq_len checks
Browse files
- config.json +1 -1
- configuration_aria.py +2 -2
- modeling_aria.py +43 -15
config.json
CHANGED
|
@@ -7,7 +7,7 @@
|
|
| 7 |
"hidden_size": 1536,
|
| 8 |
"embedding_size": 512,
|
| 9 |
"intermediate_size": 6144,
|
| 10 |
-
"
|
| 11 |
"model_type": "aria",
|
| 12 |
"num_attention_heads": 24,
|
| 13 |
"num_hidden_layers": 16,
|
|
|
|
| 7 |
"hidden_size": 1536,
|
| 8 |
"embedding_size": 512,
|
| 9 |
"intermediate_size": 6144,
|
| 10 |
+
"max_seq_len": 2048,
|
| 11 |
"model_type": "aria",
|
| 12 |
"num_attention_heads": 24,
|
| 13 |
"num_hidden_layers": 16,
|
configuration_aria.py
CHANGED
|
@@ -13,7 +13,7 @@ class AriaConfig(PretrainedConfig):
|
|
| 13 |
num_hidden_layers: int = 16,
|
| 14 |
num_attention_heads: int = 64,
|
| 15 |
intermediate_size: int = 6144,
|
| 16 |
-
|
| 17 |
use_cache: bool = True,
|
| 18 |
bos_token_id: int = 0,
|
| 19 |
eos_token_id: int = 1,
|
|
@@ -32,7 +32,7 @@ class AriaConfig(PretrainedConfig):
|
|
| 32 |
self.num_hidden_layers = num_hidden_layers
|
| 33 |
self.num_attention_heads = num_attention_heads
|
| 34 |
self.intermediate_size = intermediate_size
|
| 35 |
-
self.
|
| 36 |
self.use_cache = use_cache
|
| 37 |
self.tie_word_embeddings = tie_word_embeddings
|
| 38 |
self.output_attentions = output_attentions
|
|
|
|
| 13 |
num_hidden_layers: int = 16,
|
| 14 |
num_attention_heads: int = 64,
|
| 15 |
intermediate_size: int = 6144,
|
| 16 |
+
max_seq_len: int = 8192,
|
| 17 |
use_cache: bool = True,
|
| 18 |
bos_token_id: int = 0,
|
| 19 |
eos_token_id: int = 1,
|
|
|
|
| 32 |
self.num_hidden_layers = num_hidden_layers
|
| 33 |
self.num_attention_heads = num_attention_heads
|
| 34 |
self.intermediate_size = intermediate_size
|
| 35 |
+
self.max_seq_len = max_seq_len
|
| 36 |
self.use_cache = use_cache
|
| 37 |
self.tie_word_embeddings = tie_word_embeddings
|
| 38 |
self.output_attentions = output_attentions
|
modeling_aria.py
CHANGED
|
@@ -66,7 +66,7 @@ class TransformerBlock(nn.Module):
|
|
| 66 |
self.d_head = (
|
| 67 |
model_config.hidden_size // model_config.num_attention_heads
|
| 68 |
)
|
| 69 |
-
self.max_seq_len = model_config.
|
| 70 |
self.layer_idx = layer_idx
|
| 71 |
|
| 72 |
# Attention
|
|
@@ -257,6 +257,23 @@ class AriaModel(AriaPreTrainedModel):
|
|
| 257 |
torch.tensor: Model outputs with shape (batch_size, seq_len,
|
| 258 |
d_model).
|
| 259 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
output_attentions = (
|
| 261 |
output_attentions
|
| 262 |
if output_attentions is not None
|
|
@@ -333,7 +350,7 @@ class AriaModel(AriaPreTrainedModel):
|
|
| 333 |
|
| 334 |
if self.freqs_cis is None:
|
| 335 |
self.freqs_cis = precompute_freqs_cis(
|
| 336 |
-
seq_len=self.model_config.
|
| 337 |
n_elem=self.model_config.hidden_size
|
| 338 |
// self.model_config.num_attention_heads,
|
| 339 |
base=500000,
|
|
@@ -548,7 +565,7 @@ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
|
|
| 548 |
def __init__(self, model_config: AriaConfig):
|
| 549 |
super().__init__(model_config)
|
| 550 |
self.model_config = model_config
|
| 551 |
-
self.max_seq_len = model_config.
|
| 552 |
self.model = AriaModel(model_config)
|
| 553 |
self.lm_head = nn.Linear(
|
| 554 |
model_config.hidden_size, model_config.vocab_size, bias=False
|
|
@@ -629,13 +646,30 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
|
|
| 629 |
assert model_config.embedding_size
|
| 630 |
|
| 631 |
self.model_config = model_config
|
| 632 |
-
self.max_seq_len = model_config.
|
| 633 |
self.model = AriaModel(model_config)
|
| 634 |
self.emb_head = nn.Linear(
|
| 635 |
model_config.hidden_size, model_config.embedding_size, bias=False
|
| 636 |
)
|
| 637 |
self.post_init()
|
| 638 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
def forward(
|
| 640 |
self,
|
| 641 |
input_ids: torch.Tensor,
|
|
@@ -671,14 +705,6 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
|
|
| 671 |
):
|
| 672 |
raise ValueError("Provided args unsupported for embedding head")
|
| 673 |
|
| 674 |
-
_batch_size = input_ids.shape[0]
|
| 675 |
-
eos_mask = input_ids == self.config.eos_token_id
|
| 676 |
-
if not eos_mask.any(dim=1).all():
|
| 677 |
-
raise ValueError(
|
| 678 |
-
"Each sequence must contain at least one EOS token"
|
| 679 |
-
)
|
| 680 |
-
eos_pos = eos_mask.int().argmax(dim=1)
|
| 681 |
-
|
| 682 |
outputs = self.model(
|
| 683 |
input_ids,
|
| 684 |
attention_mask=attention_mask,
|
|
@@ -689,9 +715,11 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
|
|
| 689 |
)
|
| 690 |
hidden = outputs[0]
|
| 691 |
embedding = self.emb_head(hidden)
|
| 692 |
-
pooled_embedding =
|
| 693 |
-
|
| 694 |
-
|
|
|
|
|
|
|
| 695 |
if not return_dict:
|
| 696 |
output = (pooled_embedding,) + outputs[1:]
|
| 697 |
return output
|
|
|
|
| 66 |
self.d_head = (
|
| 67 |
model_config.hidden_size // model_config.num_attention_heads
|
| 68 |
)
|
| 69 |
+
self.max_seq_len = model_config.max_seq_len
|
| 70 |
self.layer_idx = layer_idx
|
| 71 |
|
| 72 |
# Attention
|
|
|
|
| 257 |
torch.tensor: Model outputs with shape (batch_size, seq_len,
|
| 258 |
d_model).
|
| 259 |
"""
|
| 260 |
+
if (
|
| 261 |
+
input_ids is not None
|
| 262 |
+
and input_ids.shape[1] > self.model_config.max_seq_len
|
| 263 |
+
):
|
| 264 |
+
raise ValueError(
|
| 265 |
+
f"Sequence length ({input_ids.shape[1]}) exceeds max_seq_len "
|
| 266 |
+
f"({self.model_config.max_seq_len})."
|
| 267 |
+
)
|
| 268 |
+
if (
|
| 269 |
+
inputs_embeds is not None
|
| 270 |
+
and inputs_embeds.shape[1] > self.model_config.max_seq_len
|
| 271 |
+
):
|
| 272 |
+
raise ValueError(
|
| 273 |
+
f"Sequence length ({inputs_embeds.shape[1]}) exceeds max_seq_len "
|
| 274 |
+
f"({self.model_config.max_seq_len})."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
output_attentions = (
|
| 278 |
output_attentions
|
| 279 |
if output_attentions is not None
|
|
|
|
| 350 |
|
| 351 |
if self.freqs_cis is None:
|
| 352 |
self.freqs_cis = precompute_freqs_cis(
|
| 353 |
+
seq_len=self.model_config.max_seq_len,
|
| 354 |
n_elem=self.model_config.hidden_size
|
| 355 |
// self.model_config.num_attention_heads,
|
| 356 |
base=500000,
|
|
|
|
| 565 |
def __init__(self, model_config: AriaConfig):
|
| 566 |
super().__init__(model_config)
|
| 567 |
self.model_config = model_config
|
| 568 |
+
self.max_seq_len = model_config.max_seq_len
|
| 569 |
self.model = AriaModel(model_config)
|
| 570 |
self.lm_head = nn.Linear(
|
| 571 |
model_config.hidden_size, model_config.vocab_size, bias=False
|
|
|
|
| 646 |
assert model_config.embedding_size
|
| 647 |
|
| 648 |
self.model_config = model_config
|
| 649 |
+
self.max_seq_len = model_config.max_seq_len
|
| 650 |
self.model = AriaModel(model_config)
|
| 651 |
self.emb_head = nn.Linear(
|
| 652 |
model_config.hidden_size, model_config.embedding_size, bias=False
|
| 653 |
)
|
| 654 |
self.post_init()
|
| 655 |
|
| 656 |
+
def get_pooled_embedding(
    self, input_ids: torch.Tensor, embedding: torch.Tensor
):
    """Pool per-token embeddings by picking the first EOS position.

    Args:
        input_ids: Token ids; dimension 0 is the batch and dimension 1
            is the position within the sequence (the EOS search runs
            over ``dim=1``).
        embedding: Per-token embeddings, indexed by batch on dimension 0
            and by position on dimension 1.
            NOTE(review): presumably ``emb_head`` output of shape
            (batch, seq, embedding_size) - confirm against the caller.

    Returns:
        torch.Tensor: For every sequence in the batch, the slice of
        ``embedding`` at that sequence's first EOS token.

    Raises:
        ValueError: If any sequence contains no EOS token.
    """
    eos_mask = input_ids == self.config.eos_token_id
    if not eos_mask.any(dim=1).all():
        raise ValueError(
            "Each sequence must contain at least one EOS token"
        )

    # argmax over the 0/1 mask returns the first maximal index, i.e. the
    # first EOS of each row (later EOS/padding tokens are ignored).
    eos_pos = eos_mask.int().argmax(dim=1)

    # Build the gather indices on the device of the tensor being indexed,
    # so pooling also works when ``embedding`` and ``input_ids`` live on
    # different devices.
    target_device = embedding.device
    batch_index = torch.arange(input_ids.shape[0], device=target_device)
    return embedding[batch_index, eos_pos.to(target_device)]
|
| 672 |
+
|
| 673 |
def forward(
|
| 674 |
self,
|
| 675 |
input_ids: torch.Tensor,
|
|
|
|
| 705 |
):
|
| 706 |
raise ValueError("Provided args unsupported for embedding head")
|
| 707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
outputs = self.model(
|
| 709 |
input_ids,
|
| 710 |
attention_mask=attention_mask,
|
|
|
|
| 715 |
)
|
| 716 |
hidden = outputs[0]
|
| 717 |
embedding = self.emb_head(hidden)
|
| 718 |
+
pooled_embedding = self.get_pooled_embedding(
|
| 719 |
+
input_ids=input_ids,
|
| 720 |
+
embedding=embedding,
|
| 721 |
+
)
|
| 722 |
+
|
| 723 |
if not return_dict:
|
| 724 |
output = (pooled_embedding,) + outputs[1:]
|
| 725 |
return output
|