Update modeling_custom_llama.py

modeling_custom_llama.py (+2 -23)
@@ -1,4 +1,5 @@
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import MODEL_FOR_MASKED_LM_MAPPING
 from transformers import PretrainedConfig
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers import GPT2TokenizerFast
@@ -122,26 +123,6 @@ class CustomLlamaAttention(LlamaAttention):
 
         attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir)
 
-        # # Handle attention mask and causality
-        # if attention_mask is not None:
-        #     # Convert padding mask [batch_size, seq_len] to [batch_size, 1, 1, seq_len]
-        #     padding_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        #     padding_mask = (1.0 - padding_mask) * torch.finfo(attn_scores.dtype).min
-        #     if is_causal is not None:
-        #         causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
-        #         is_causal_expanded = is_causal.view(-1, 1, 1, 1)
-        #         attention_mask = padding_mask + (causal_mask * is_causal_expanded)
-        #     else:
-        #         attention_mask = padding_mask
-        # else:
-        #     if is_causal is not None:
-        #         causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
-        #         is_causal_expanded = is_causal.view(-1, 1, 1, 1)
-        #         attention_mask = causal_mask * is_causal_expanded
-        #     else:
-        #         attention_mask = torch.zeros_like(attn_scores)
-
-        # attn_scores = attn_scores + attention_mask
         # Replace existing mask logic with:
         if attention_mask is not None:
             padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
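
Note: this hunk drops the old commented-out additive-mask logic and keeps the newer code that builds a boolean padding mask directly from attention_mask. As an illustration of the masking this code is meant to perform, here is a minimal, self-contained sketch that folds a padding mask and a causal mask into one additive bias; the helper name build_attention_bias and the exact bias formulation are assumptions for this note, not code from this file.

import torch

def build_attention_bias(attention_mask, seq_len, dtype, device):
    # attention_mask: [batch_size, seq_len], 1 for real tokens, 0 for padding.
    # Returns an additive bias of shape [batch_size, 1, seq_len, seq_len]
    # to add to raw attention scores before softmax.
    batch_size = attention_mask.shape[0]
    min_value = torch.finfo(dtype).min

    # Padding columns get a large negative value, broadcast over query rows.
    padding_bias = torch.zeros(batch_size, 1, 1, seq_len, dtype=dtype, device=device)
    padding_bias = padding_bias.masked_fill(
        (attention_mask == 0).view(batch_size, 1, 1, seq_len), min_value
    )

    # Strictly upper-triangular causal bias: query i cannot attend to key j > i.
    causal_bias = torch.triu(
        torch.full((seq_len, seq_len), min_value, dtype=dtype, device=device),
        diagonal=1,
    ).view(1, 1, seq_len, seq_len)

    # Positions excluded by both masks may sum to -inf; softmax still treats
    # them as fully masked (their weight becomes zero).
    return padding_bias + causal_bias

# The scores would then be updated along the lines of:
# attn_scores = attn_scores + build_attention_bias(
#     attention_mask, seq_len, attn_scores.dtype, attn_scores.device)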
@@ -324,8 +305,6 @@ class CustomLlamaForCausalLM(LlamaForCausalLM):
         )
 
         return ModelOutput(loss=loss, logits=logits)
-        # return {"loss": loss, "logits": logits}
-        # return {"loss": loss, "logits": logits} if return_dict else (loss, logits)
 
 class CustomLlamaForMaskedLM(CustomLlamaForCausalLM):
     config_class = CustomLlamaConfig  # Add this line
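
Note: the two commented-out dict/tuple returns are removed, so the forward pass now returns a single ModelOutput. For context only (this class is not used in the diff), a typed output such as MaskedLMOutput shows why a structured return is preferred over a raw dict or tuple: it supports attribute access, key access, and conversion to a tuple.

import torch
from transformers.modeling_outputs import MaskedLMOutput

out = MaskedLMOutput(loss=torch.tensor(0.5), logits=torch.zeros(1, 4, 128))
print(out.loss)             # attribute access
print(out["logits"].shape)  # dict-style access
print(out.to_tuple())       # (loss, logits), for code paths that expect a tuple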
@@ -369,6 +348,6 @@ MODEL_MAPPING.update({"custom_llama": CustomLlamaForMaskedLM})
 def _register():
     from transformers import AutoConfig, AutoModelForCausalLM
     AutoConfig.register("custom_llama", CustomLlamaConfig)
-
+    MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
 
 _register()
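
Note: together with the new top-level import, the added registration line ties CustomLlamaConfig to CustomLlamaForMaskedLM in the masked-LM auto-mapping. A minimal usage sketch, assuming CustomLlamaConfig can be built from its default arguments and that the module has been imported so _register() has already run (illustrative only, not part of this commit):

from transformers import AutoConfig, AutoModelForMaskedLM
import modeling_custom_llama  # importing the module runs _register()

config = AutoConfig.for_model("custom_llama")     # resolves to CustomLlamaConfig
model = AutoModelForMaskedLM.from_config(config)  # resolves to CustomLlamaForMaskedLM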