Fix grad checkpoint and outputs param
Browse files
src/axolotl/monkeypatch/llama_landmark_attn.py
CHANGED
|
@@ -27,7 +27,6 @@ from typing import List, Optional, Tuple, Union
|
|
| 27 |
|
| 28 |
import torch
|
| 29 |
import torch.utils.checkpoint
|
| 30 |
- import transformers
|
| 31 |
from torch import nn
|
| 32 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 33 |
from transformers.activations import ACT2FN
|
|
@@ -52,10 +51,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|
| 52 |
MEM_TOKEN = "<landmark>" # nosec
|
| 53 |
|
| 54 |
|
| 55 |
- def hijack_llama_landmark_attn():
|
| 56 |
-     transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
| 57 |
-
|
| 58 |
-
|
| 59 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
| 60 |
def _make_causal_mask(
|
| 61 |
input_ids_shape: torch.Size,
|
|
@@ -1125,7 +1120,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
| 1125 |
def create_custom_forward(module):
|
| 1126 |
def custom_forward(*inputs):
|
| 1127 |
# None for past_key_value
|
| 1128 |
-                     return module(*inputs
|
| 1129 |
|
| 1130 |
return custom_forward
|
| 1131 |
|
|
@@ -1135,6 +1130,8 @@ class LlamaModel(LlamaPreTrainedModel):
|
|
| 1135 |
attention_mask,
|
| 1136 |
position_ids,
|
| 1137 |
None,
|
|
|
|
|
|
|
| 1138 |
is_mem,
|
| 1139 |
last_section_mask,
|
| 1140 |
)
|
|
@@ -1300,7 +1297,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|
| 1300 |
return_dict=return_dict,
|
| 1301 |
offload_cache_to_cpu=offload_cache_to_cpu,
|
| 1302 |
)
|
| 1303 |
-         past_key_values = outputs
|
| 1304 |
if last_logits is not None:
|
| 1305 |
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
| 1306 |
last_logits = outputs[0]
|
|
|
|
| 27 |
|
| 28 |
import torch
|
| 29 |
import torch.utils.checkpoint
|
|
|
|
| 30 |
from torch import nn
|
| 31 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 32 |
from transformers.activations import ACT2FN
|
|
|
|
| 51 |
MEM_TOKEN = "<landmark>" # nosec
|
| 52 |
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
| 55 |
def _make_causal_mask(
|
| 56 |
input_ids_shape: torch.Size,
|
|
|
|
| 1120 |
def create_custom_forward(module):
|
| 1121 |
def custom_forward(*inputs):
|
| 1122 |
# None for past_key_value
|
| 1123 |
+                     return module(*inputs)
|
| 1124 |
|
| 1125 |
return custom_forward
|
| 1126 |
|
|
|
|
| 1130 |
attention_mask,
|
| 1131 |
position_ids,
|
| 1132 |
None,
|
| 1133 |
+                     output_attentions,
|
| 1134 |
+                     None,
|
| 1135 |
is_mem,
|
| 1136 |
last_section_mask,
|
| 1137 |
)
|
|
|
|
| 1297 |
return_dict=return_dict,
|
| 1298 |
offload_cache_to_cpu=offload_cache_to_cpu,
|
| 1299 |
)
|
| 1300 |
+         past_key_values = outputs.past_key_values
|
| 1301 |
if last_logits is not None:
|
| 1302 |
last_logits = torch.cat((last_logits, outputs[0]), dim=-2)
|
| 1303 |
last_logits = outputs[0]
|