fix
Browse files — modeling_aria.py (+2 −3)
modeling_aria.py
CHANGED
|
@@ -356,8 +356,7 @@ class AriaModel(AriaPreTrainedModel):
|
|
| 356 |
base=500000,
|
| 357 |
dtype=hidden_states.dtype,
|
| 358 |
).to(input_ids.device)
|
| 359 | - freqs_cis = self.freqs_cis[
| 360 | -     cache_position]   [NOTE(review): this removed line's content was lost in extraction; reconstructed from the added line `freqs_cis = self.freqs_cis[cache_position]` and the diff stat (+2 −3) — confirm against the original commit]
| 361 |   kwargs = {
|
| 362 |
"position_ids": position_ids,
|
| 363 |
"past_key_values": past_key_values,
|
|
@@ -475,7 +474,7 @@ class AriaModel(AriaPreTrainedModel):
|
|
| 475 |
target_length = (
|
| 476 |
attention_mask.shape[-1]
|
| 477 |
if isinstance(attention_mask, torch.Tensor)
|
| 478 | - else past_seen_tokens + sequence_length
|
| 479 |
)
|
| 480 |
|
| 481 |
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
|
|
|
| 356 |
base=500000,
|
| 357 |
dtype=hidden_states.dtype,
|
| 358 |
).to(input_ids.device)
|
| 359 | + freqs_cis = self.freqs_cis[cache_position]
|
|
|
|
| 360 |
kwargs = {
|
| 361 |
"position_ids": position_ids,
|
| 362 |
"past_key_values": past_key_values,
|
|
|
|
| 474 |
target_length = (
|
| 475 |
attention_mask.shape[-1]
|
| 476 |
if isinstance(attention_mask, torch.Tensor)
|
| 477 | + else past_seen_tokens + sequence_length
|
| 478 |
)
|
| 479 |
|
| 480 |
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|