loua19 committed on
Commit
57bb396
·
1 Parent(s): fdd8ec9
Files changed (1) hide show
  1. modeling_aria.py +2 -3
modeling_aria.py CHANGED
@@ -356,8 +356,7 @@ class AriaModel(AriaPreTrainedModel):
356
  base=500000,
357
  dtype=hidden_states.dtype,
358
  ).to(input_ids.device)
359
- freqs_cis = self.freqs_cis[: input_ids.shape[1]]
360
-
361
  kwargs = {
362
  "position_ids": position_ids,
363
  "past_key_values": past_key_values,
@@ -475,7 +474,7 @@ class AriaModel(AriaPreTrainedModel):
475
  target_length = (
476
  attention_mask.shape[-1]
477
  if isinstance(attention_mask, torch.Tensor)
478
- else past_seen_tokens + sequence_length + 1
479
  )
480
 
481
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
 
356
  base=500000,
357
  dtype=hidden_states.dtype,
358
  ).to(input_ids.device)
359
+ freqs_cis = self.freqs_cis[cache_position]
 
360
  kwargs = {
361
  "position_ids": position_ids,
362
  "past_key_values": past_key_values,
 
474
  target_length = (
475
  attention_mask.shape[-1]
476
  if isinstance(attention_mask, torch.Tensor)
477
+ else past_seen_tokens + sequence_length
478
  )
479
 
480
  # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).