loua19 committed on
Commit
2c49962
·
1 Parent(s): 13425d0

add max_seq_len checks

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. configuration_aria.py +2 -2
  3. modeling_aria.py +43 -15
config.json CHANGED
@@ -7,7 +7,7 @@
7
  "hidden_size": 1536,
8
  "embedding_size": 512,
9
  "intermediate_size": 6144,
10
- "max_position_embeddings": 8192,
11
  "model_type": "aria",
12
  "num_attention_heads": 24,
13
  "num_hidden_layers": 16,
 
7
  "hidden_size": 1536,
8
  "embedding_size": 512,
9
  "intermediate_size": 6144,
10
+ "max_seq_len": 2048,
11
  "model_type": "aria",
12
  "num_attention_heads": 24,
13
  "num_hidden_layers": 16,
configuration_aria.py CHANGED
@@ -13,7 +13,7 @@ class AriaConfig(PretrainedConfig):
13
  num_hidden_layers: int = 16,
14
  num_attention_heads: int = 64,
15
  intermediate_size: int = 6144,
16
- max_position_embeddings: int = 8192,
17
  use_cache: bool = True,
18
  bos_token_id: int = 0,
19
  eos_token_id: int = 1,
@@ -32,7 +32,7 @@ class AriaConfig(PretrainedConfig):
32
  self.num_hidden_layers = num_hidden_layers
33
  self.num_attention_heads = num_attention_heads
34
  self.intermediate_size = intermediate_size
35
- self.max_position_embeddings = max_position_embeddings
36
  self.use_cache = use_cache
37
  self.tie_word_embeddings = tie_word_embeddings
38
  self.output_attentions = output_attentions
 
13
  num_hidden_layers: int = 16,
14
  num_attention_heads: int = 64,
15
  intermediate_size: int = 6144,
16
+ max_seq_len: int = 8192,
17
  use_cache: bool = True,
18
  bos_token_id: int = 0,
19
  eos_token_id: int = 1,
 
32
  self.num_hidden_layers = num_hidden_layers
33
  self.num_attention_heads = num_attention_heads
34
  self.intermediate_size = intermediate_size
35
+ self.max_seq_len = max_seq_len
36
  self.use_cache = use_cache
37
  self.tie_word_embeddings = tie_word_embeddings
38
  self.output_attentions = output_attentions
modeling_aria.py CHANGED
@@ -66,7 +66,7 @@ class TransformerBlock(nn.Module):
66
  self.d_head = (
67
  model_config.hidden_size // model_config.num_attention_heads
68
  )
69
- self.max_seq_len = model_config.max_position_embeddings
70
  self.layer_idx = layer_idx
71
 
72
  # Attention
@@ -257,6 +257,23 @@ class AriaModel(AriaPreTrainedModel):
257
  torch.tensor: Model outputs with shape (batch_size, seq_len,
258
  d_model).
259
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  output_attentions = (
261
  output_attentions
262
  if output_attentions is not None
@@ -333,7 +350,7 @@ class AriaModel(AriaPreTrainedModel):
333
 
334
  if self.freqs_cis is None:
335
  self.freqs_cis = precompute_freqs_cis(
336
- seq_len=self.model_config.max_position_embeddings,
337
  n_elem=self.model_config.hidden_size
338
  // self.model_config.num_attention_heads,
339
  base=500000,
@@ -548,7 +565,7 @@ class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
548
  def __init__(self, model_config: AriaConfig):
549
  super().__init__(model_config)
550
  self.model_config = model_config
551
- self.max_seq_len = model_config.max_position_embeddings
552
  self.model = AriaModel(model_config)
553
  self.lm_head = nn.Linear(
554
  model_config.hidden_size, model_config.vocab_size, bias=False
@@ -629,13 +646,30 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
629
  assert model_config.embedding_size
630
 
631
  self.model_config = model_config
632
- self.max_seq_len = model_config.max_position_embeddings
633
  self.model = AriaModel(model_config)
634
  self.emb_head = nn.Linear(
635
  model_config.hidden_size, model_config.embedding_size, bias=False
636
  )
637
  self.post_init()
638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
  def forward(
640
  self,
641
  input_ids: torch.Tensor,
@@ -671,14 +705,6 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
671
  ):
672
  raise ValueError("Provided args unsupported for embedding head")
673
 
674
- _batch_size = input_ids.shape[0]
675
- eos_mask = input_ids == self.config.eos_token_id
676
- if not eos_mask.any(dim=1).all():
677
- raise ValueError(
678
- "Each sequence must contain at least one EOS token"
679
- )
680
- eos_pos = eos_mask.int().argmax(dim=1)
681
-
682
  outputs = self.model(
683
  input_ids,
684
  attention_mask=attention_mask,
@@ -689,9 +715,11 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
689
  )
690
  hidden = outputs[0]
691
  embedding = self.emb_head(hidden)
692
- pooled_embedding = embedding[
693
- torch.arange(_batch_size, device=input_ids.device), eos_pos
694
- ]
 
 
695
  if not return_dict:
696
  output = (pooled_embedding,) + outputs[1:]
697
  return output
 
66
  self.d_head = (
67
  model_config.hidden_size // model_config.num_attention_heads
68
  )
69
+ self.max_seq_len = model_config.max_seq_len
70
  self.layer_idx = layer_idx
71
 
72
  # Attention
 
257
  torch.tensor: Model outputs with shape (batch_size, seq_len,
258
  d_model).
259
  """
260
+ if (
261
+ input_ids is not None
262
+ and input_ids.shape[1] > self.model_config.max_seq_len
263
+ ):
264
+ raise ValueError(
265
+ f"Sequence length ({input_ids.shape[1]}) exceeds max_seq_len "
266
+ f"({self.model_config.max_seq_len})."
267
+ )
268
+ if (
269
+ inputs_embeds is not None
270
+ and inputs_embeds.shape[1] > self.model_config.max_seq_len
271
+ ):
272
+ raise ValueError(
273
+ f"Sequence length ({inputs_embeds.shape[1]}) exceeds max_seq_len "
274
+ f"({self.model_config.max_seq_len})."
275
+ )
276
+
277
  output_attentions = (
278
  output_attentions
279
  if output_attentions is not None
 
350
 
351
  if self.freqs_cis is None:
352
  self.freqs_cis = precompute_freqs_cis(
353
+ seq_len=self.model_config.max_seq_len,
354
  n_elem=self.model_config.hidden_size
355
  // self.model_config.num_attention_heads,
356
  base=500000,
 
565
  def __init__(self, model_config: AriaConfig):
566
  super().__init__(model_config)
567
  self.model_config = model_config
568
+ self.max_seq_len = model_config.max_seq_len
569
  self.model = AriaModel(model_config)
570
  self.lm_head = nn.Linear(
571
  model_config.hidden_size, model_config.vocab_size, bias=False
 
646
  assert model_config.embedding_size
647
 
648
  self.model_config = model_config
649
+ self.max_seq_len = model_config.max_seq_len
650
  self.model = AriaModel(model_config)
651
  self.emb_head = nn.Linear(
652
  model_config.hidden_size, model_config.embedding_size, bias=False
653
  )
654
  self.post_init()
655
 
656
+ # Pool one embedding vector per sequence by selecting, for each row of
+ # input_ids, the embedding at the position of that row's FIRST EOS token.
+ def get_pooled_embedding(
657
+ self, input_ids: torch.Tensor, embedding: torch.Tensor
658
+ ):
659
+ # Batch size taken from input_ids; embedding is assumed to share the
+ # same batch/sequence layout (batch, seq, emb_dim) — TODO confirm.
+ _batch_size = input_ids.shape[0]
660
+ # Boolean mask of EOS positions, per token.
+ eos_mask = input_ids == self.config.eos_token_id
661
+ # Every row must contain at least one EOS, otherwise pooling below
+ # would silently pick position 0 (argmax of an all-zero mask).
+ if not eos_mask.any(dim=1).all():
662
+ raise ValueError(
663
+ "Each sequence must contain at least one EOS token"
664
+ )
665
+ # argmax over the int mask returns the index of the first EOS per row.
+ eos_pos = eos_mask.int().argmax(dim=1)
666
+
667
+ # Advanced indexing: for each batch row i, take embedding[i, eos_pos[i]].
+ pooled_embedding = embedding[
668
+ torch.arange(_batch_size, device=input_ids.device), eos_pos
669
+ ]
670
+
671
+ return pooled_embedding
672
+
673
  def forward(
674
  self,
675
  input_ids: torch.Tensor,
 
705
  ):
706
  raise ValueError("Provided args unsupported for embedding head")
707
 
 
 
 
 
 
 
 
 
708
  outputs = self.model(
709
  input_ids,
710
  attention_mask=attention_mask,
 
715
  )
716
  hidden = outputs[0]
717
  embedding = self.emb_head(hidden)
718
+ pooled_embedding = self.get_pooled_embedding(
719
+ input_ids=input_ids,
720
+ embedding=embedding,
721
+ )
722
+
723
  if not return_dict:
724
  output = (pooled_embedding,) + outputs[1:]
725
  return output