davda54 committed on
Commit
75b4008
·
verified ·
1 Parent(s): df98cf0

Fix causal mode

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +15 -8
modeling_gptbert.py CHANGED
@@ -333,7 +333,7 @@ class MaskedSoftmax(torch.autograd.Function):
333
 
334
 
335
  class SelfAttention(nn.Module):
336
- def __init__(self, config: GptBertConfig, layer_idx: int):
337
  super().__init__()
338
 
339
  self.config = config
@@ -349,6 +349,8 @@ class SelfAttention(nn.Module):
349
  self.k_out_dim = self.d_qk * self.num_kv_heads
350
  self.v_out_dim = self.d_v * self.num_kv_heads
351
 
 
 
352
  self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
353
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
354
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
@@ -376,7 +378,6 @@ class SelfAttention(nn.Module):
376
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
377
 
378
  self.sequence_length = config.max_sequence_length
379
- self.is_causal = config.is_decoder
380
  self.window_length = None
381
 
382
  def set_window_length(self, window_length: int):
@@ -526,10 +527,10 @@ class FeedForward(nn.Module):
526
 
527
 
528
  class Layer(nn.Module):
529
- def __init__(self, config: GptBertConfig, layer_idx: int):
530
  super().__init__()
531
 
532
- self.attention = SelfAttention(config, layer_idx)
533
  self.mlp = FeedForward(config)
534
  self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
535
 
@@ -550,9 +551,9 @@ class Layer(nn.Module):
550
 
551
 
552
  class Encoder(nn.Module):
553
- def __init__(self, config: GptBertConfig):
554
  super().__init__()
555
- self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
556
  self.local_global_ratio = config.local_global_ratio
557
 
558
  def set_window_length(self, config: GptBertConfig):
@@ -613,9 +614,10 @@ class GptBertModel(GptBertPreTrainedModel):
613
  super().__init__(config, **kwargs)
614
  self.config = config
615
  self.hidden_size = config.hidden_size
 
616
 
617
  self.embedding = Embedding(config)
618
- self.encoder = Encoder(config)
619
  self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
620
  self.set_window_length(config)
621
  self.gradient_checkpointing = False
@@ -718,6 +720,7 @@ class GptBertForMaskedLM(GptBertModel):
718
  _tied_weights_keys = ["classifier.emb2vocab.weight"]
719
 
720
  def __init__(self, config: GptBertConfig, **kwargs):
 
721
  super().__init__(config, add_mlm_layer=True, **kwargs)
722
 
723
  def get_output_embeddings(self):
@@ -769,7 +772,7 @@ class GptBertForCausalLM(GptBertModel):
769
  _tied_weights_keys = ["classifier.emb2vocab.weight"]
770
 
771
  def __init__(self, config: GptBertConfig, **kwargs):
772
- config.is_decoder = True
773
  super().__init__(config, add_mlm_layer=True, **kwargs)
774
 
775
  def get_output_embeddings(self):
@@ -886,6 +889,7 @@ class GptBertForSequenceClassification(GptBertModel):
886
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
887
 
888
  def __init__(self, config: GptBertConfig, **kwargs):
 
889
  super().__init__(config, add_mlm_layer=False, **kwargs)
890
 
891
  self.num_labels = config.num_labels
@@ -941,6 +945,7 @@ class GptBertForTokenClassification(GptBertModel):
941
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
942
 
943
  def __init__(self, config: GptBertConfig, **kwargs):
 
944
  super().__init__(config, add_mlm_layer=False, **kwargs)
945
 
946
  self.num_labels = config.num_labels
@@ -978,6 +983,7 @@ class GptBertForQuestionAnswering(GptBertModel):
978
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
979
 
980
  def __init__(self, config: GptBertConfig, **kwargs):
 
981
  super().__init__(config, add_mlm_layer=False, **kwargs)
982
 
983
  self.num_labels = config.num_labels
@@ -1034,6 +1040,7 @@ class GptBertForMultipleChoice(GptBertModel):
1034
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1035
 
1036
  def __init__(self, config: GptBertConfig, **kwargs):
 
1037
  super().__init__(config, add_mlm_layer=False, **kwargs)
1038
 
1039
  self.num_labels = getattr(config, "num_labels", 2)
 
333
 
334
 
335
  class SelfAttention(nn.Module):
336
+ def __init__(self, config: GptBertConfig, layer_idx: int, is_decoder: bool):
337
  super().__init__()
338
 
339
  self.config = config
 
349
  self.k_out_dim = self.d_qk * self.num_kv_heads
350
  self.v_out_dim = self.d_v * self.num_kv_heads
351
 
352
+ self.is_causal = is_decoder
353
+
354
  self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
355
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
356
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
 
378
  self.lambdas = nn.Parameter(torch.tensor([0.5]))
379
 
380
  self.sequence_length = config.max_sequence_length
 
381
  self.window_length = None
382
 
383
  def set_window_length(self, window_length: int):
 
527
 
528
 
529
  class Layer(nn.Module):
530
+ def __init__(self, config: GptBertConfig, layer_idx: int, is_decoder: bool):
531
  super().__init__()
532
 
533
+ self.attention = SelfAttention(config, layer_idx, is_decoder)
534
  self.mlp = FeedForward(config)
535
  self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
536
 
 
551
 
552
 
553
  class Encoder(nn.Module):
554
+ def __init__(self, config: GptBertConfig, is_decoder: bool):
555
  super().__init__()
556
+ self.layers = nn.ModuleList([Layer(config, i, is_decoder) for i in range(config.num_layers)])
557
  self.local_global_ratio = config.local_global_ratio
558
 
559
  def set_window_length(self, config: GptBertConfig):
 
614
  super().__init__(config, **kwargs)
615
  self.config = config
616
  self.hidden_size = config.hidden_size
617
+ self.is_decoder = self.is_decoder if hasattr(self, "is_decoder") else False
618
 
619
  self.embedding = Embedding(config)
620
+ self.encoder = Encoder(config, self.is_decoder)
621
  self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
622
  self.set_window_length(config)
623
  self.gradient_checkpointing = False
 
720
  _tied_weights_keys = ["classifier.emb2vocab.weight"]
721
 
722
  def __init__(self, config: GptBertConfig, **kwargs):
723
+ self.is_decoder = False
724
  super().__init__(config, add_mlm_layer=True, **kwargs)
725
 
726
  def get_output_embeddings(self):
 
772
  _tied_weights_keys = ["classifier.emb2vocab.weight"]
773
 
774
  def __init__(self, config: GptBertConfig, **kwargs):
775
+ self.is_decoder = True
776
  super().__init__(config, add_mlm_layer=True, **kwargs)
777
 
778
  def get_output_embeddings(self):
 
889
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
890
 
891
  def __init__(self, config: GptBertConfig, **kwargs):
892
+ self.is_decoder = False
893
  super().__init__(config, add_mlm_layer=False, **kwargs)
894
 
895
  self.num_labels = config.num_labels
 
945
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
946
 
947
  def __init__(self, config: GptBertConfig, **kwargs):
948
+ self.is_decoder = False
949
  super().__init__(config, add_mlm_layer=False, **kwargs)
950
 
951
  self.num_labels = config.num_labels
 
983
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
984
 
985
  def __init__(self, config: GptBertConfig, **kwargs):
986
+ self.is_decoder = False
987
  super().__init__(config, add_mlm_layer=False, **kwargs)
988
 
989
  self.num_labels = config.num_labels
 
1040
  _keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
1041
 
1042
  def __init__(self, config: GptBertConfig, **kwargs):
1043
+ self.is_decoder = False
1044
  super().__init__(config, add_mlm_layer=False, **kwargs)
1045
 
1046
  self.num_labels = getattr(config, "num_labels", 2)