Shounak committed (verified)
Commit a67e8a1 · 1 Parent(s): 73009d0

Update modeling_custom_llama.py

Files changed (1)
  1. modeling_custom_llama.py +2 -23
modeling_custom_llama.py CHANGED
@@ -1,4 +1,5 @@
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers import MODEL_FOR_MASKED_LM_MAPPING
 from transformers import PretrainedConfig
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers import GPT2TokenizerFast
@@ -122,26 +123,6 @@ class CustomLlamaAttention(LlamaAttention):
 
         attn_scores = self._compute_metric(q_start, q_dir, k_start, k_dir)
 
-        # # Handle attention mask and causality
-        # if attention_mask is not None:
-        #     # Convert padding mask [batch_size, seq_len] to [batch_size, 1, 1, seq_len]
-        #     padding_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        #     padding_mask = (1.0 - padding_mask) * torch.finfo(attn_scores.dtype).min
-        #     if is_causal is not None:
-        #         causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
-        #         is_causal_expanded = is_causal.view(-1, 1, 1, 1)
-        #         attention_mask = padding_mask + (causal_mask * is_causal_expanded)
-        #     else:
-        #         attention_mask = padding_mask
-        # else:
-        #     if is_causal is not None:
-        #         causal_mask = self._get_causal_mask(seq_len, seq_len, attn_scores.device)
-        #         is_causal_expanded = is_causal.view(-1, 1, 1, 1)
-        #         attention_mask = causal_mask * is_causal_expanded
-        #     else:
-        #         attention_mask = torch.zeros_like(attn_scores)
-
-        # attn_scores = attn_scores + attention_mask
         # Replace existing mask logic with:
         if attention_mask is not None:
             padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
@@ -324,8 +305,6 @@ class CustomLlamaForCausalLM(LlamaForCausalLM):
         )
 
         return ModelOutput(loss=loss, logits=logits)
-        # return {"loss": loss, "logits": logits}
-        # return {"loss": loss, "logits": logits} if return_dict else (loss, logits)
 
 class CustomLlamaForMaskedLM(CustomLlamaForCausalLM):
     config_class = CustomLlamaConfig  # Add this line
@@ -369,6 +348,6 @@ MODEL_MAPPING.update({"custom_llama": CustomLlamaForMaskedLM})
 def _register():
     from transformers import AutoConfig, AutoModelForCausalLM
     AutoConfig.register("custom_llama", CustomLlamaConfig)
-    # AutoModelForCausalLM.register(CustomLlamaConfig, CustomLlamaForCausalLM)
+    MODEL_FOR_MASKED_LM_MAPPING.register(CustomLlamaConfig, CustomLlamaForMaskedLM)
 
 _register()
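For context on the mask handling this commit switches to: the hunk above only shows the first line of the replacement, a boolean padding mask reshaped to [batch_size, 1, 1, seq_len]. Below is a minimal, hypothetical sketch of how such a boolean padding mask and a causal constraint could be combined into an additive bias on the attention scores; `attn_scores`, `attention_mask`, and `batch_size` follow the diff, while the helper name `apply_masks` and the `is_causal` flag are assumptions, not the file's actual code.

import torch

# Hypothetical helper, not the code in modeling_custom_llama.py: turn a boolean
# padding mask plus an optional causal constraint into an additive bias and apply it.
def apply_masks(attn_scores, attention_mask=None, is_causal=True):
    batch_size, _, q_len, k_len = attn_scores.shape
    neg_inf = torch.finfo(attn_scores.dtype).min
    bias = torch.zeros_like(attn_scores)

    if attention_mask is not None:
        # True where a key position is padding, broadcast over heads and query positions.
        padding_mask = (attention_mask == 0).view(batch_size, 1, 1, -1)
        bias = bias.masked_fill(padding_mask, neg_inf)

    if is_causal:
        # True strictly above the diagonal, i.e. future key positions.
        causal = torch.triu(
            torch.ones(q_len, k_len, dtype=torch.bool, device=attn_scores.device),
            diagonal=1,
        )
        bias = bias.masked_fill(causal, neg_inf)

    return attn_scores + bias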
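The net effect of the registration change is that `_register()` now maps "custom_llama" to `CustomLlamaConfig` via `AutoConfig.register` and maps `CustomLlamaConfig` to `CustomLlamaForMaskedLM` in the masked-LM auto mapping, replacing the commented-out causal-LM registration. A rough usage sketch, assuming the module is importable locally and that `CustomLlamaConfig` can be constructed with its defaults (both assumptions, not part of the commit):

# Hypothetical usage sketch; importing the module runs _register() at the bottom of the file.
import modeling_custom_llama  # registers "custom_llama" with AutoConfig and the masked-LM mapping

from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.for_model("custom_llama")      # resolved via AutoConfig.register(...)
model = AutoModelForMaskedLM.from_config(config)   # resolved via MODEL_FOR_MASKED_LM_MAPPING.register(...)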