Fix causal mode
Browse files — modeling_gptbert.py (+15 −8)
modeling_gptbert.py
CHANGED
|
@@ -333,7 +333,7 @@ class MaskedSoftmax(torch.autograd.Function):
|
|
| 333 |
|
| 334 |
|
| 335 |
class SelfAttention(nn.Module):
|
| 336 |
-
def __init__(self, config: GptBertConfig, layer_idx: int):
|
| 337 |
super().__init__()
|
| 338 |
|
| 339 |
self.config = config
|
|
@@ -349,6 +349,8 @@ class SelfAttention(nn.Module):
|
|
| 349 |
self.k_out_dim = self.d_qk * self.num_kv_heads
|
| 350 |
self.v_out_dim = self.d_v * self.num_kv_heads
|
| 351 |
|
|
|
|
|
|
|
| 352 |
self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
|
| 353 |
self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
|
| 354 |
self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
|
|
@@ -376,7 +378,6 @@ class SelfAttention(nn.Module):
|
|
| 376 |
self.lambdas = nn.Parameter(torch.tensor([0.5]))
|
| 377 |
|
| 378 |
self.sequence_length = config.max_sequence_length
|
| 379 |
-
self.is_causal = config.is_decoder
|
| 380 |
self.window_length = None
|
| 381 |
|
| 382 |
def set_window_length(self, window_length: int):
|
|
@@ -526,10 +527,10 @@ class FeedForward(nn.Module):
|
|
| 526 |
|
| 527 |
|
| 528 |
class Layer(nn.Module):
|
| 529 |
-
def __init__(self, config: GptBertConfig, layer_idx: int):
|
| 530 |
super().__init__()
|
| 531 |
|
| 532 |
-
self.attention = SelfAttention(config, layer_idx)
|
| 533 |
self.mlp = FeedForward(config)
|
| 534 |
self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
|
| 535 |
|
|
@@ -550,9 +551,9 @@ class Layer(nn.Module):
|
|
| 550 |
|
| 551 |
|
| 552 |
class Encoder(nn.Module):
|
| 553 |
-
def __init__(self, config: GptBertConfig):
|
| 554 |
super().__init__()
|
| 555 |
-
self.layers = nn.ModuleList([Layer(config, i) for i in range(config.num_layers)])
|
| 556 |
self.local_global_ratio = config.local_global_ratio
|
| 557 |
|
| 558 |
def set_window_length(self, config: GptBertConfig):
|
|
@@ -613,9 +614,10 @@ class GptBertModel(GptBertPreTrainedModel):
|
|
| 613 |
super().__init__(config, **kwargs)
|
| 614 |
self.config = config
|
| 615 |
self.hidden_size = config.hidden_size
|
|
|
|
| 616 |
|
| 617 |
self.embedding = Embedding(config)
|
| 618 |
-
self.encoder = Encoder(config)
|
| 619 |
self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
|
| 620 |
self.set_window_length(config)
|
| 621 |
self.gradient_checkpointing = False
|
|
@@ -718,6 +720,7 @@ class GptBertForMaskedLM(GptBertModel):
|
|
| 718 |
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 719 |
|
| 720 |
def __init__(self, config: GptBertConfig, **kwargs):
|
|
|
|
| 721 |
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 722 |
|
| 723 |
def get_output_embeddings(self):
|
|
@@ -769,7 +772,7 @@ class GptBertForCausalLM(GptBertModel):
|
|
| 769 |
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 770 |
|
| 771 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 772 |
-
|
| 773 |
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 774 |
|
| 775 |
def get_output_embeddings(self):
|
|
@@ -886,6 +889,7 @@ class GptBertForSequenceClassification(GptBertModel):
|
|
| 886 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 887 |
|
| 888 |
def __init__(self, config: GptBertConfig, **kwargs):
|
|
|
|
| 889 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 890 |
|
| 891 |
self.num_labels = config.num_labels
|
|
@@ -941,6 +945,7 @@ class GptBertForTokenClassification(GptBertModel):
|
|
| 941 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 942 |
|
| 943 |
def __init__(self, config: GptBertConfig, **kwargs):
|
|
|
|
| 944 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 945 |
|
| 946 |
self.num_labels = config.num_labels
|
|
@@ -978,6 +983,7 @@ class GptBertForQuestionAnswering(GptBertModel):
|
|
| 978 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 979 |
|
| 980 |
def __init__(self, config: GptBertConfig, **kwargs):
|
|
|
|
| 981 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 982 |
|
| 983 |
self.num_labels = config.num_labels
|
|
@@ -1034,6 +1040,7 @@ class GptBertForMultipleChoice(GptBertModel):
|
|
| 1034 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1035 |
|
| 1036 |
def __init__(self, config: GptBertConfig, **kwargs):
|
|
|
|
| 1037 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1038 |
|
| 1039 |
self.num_labels = getattr(config, "num_labels", 2)
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
class SelfAttention(nn.Module):
|
| 336 |
+
def __init__(self, config: GptBertConfig, layer_idx: int, is_decoder: bool):
|
| 337 |
super().__init__()
|
| 338 |
|
| 339 |
self.config = config
|
|
|
|
| 349 |
self.k_out_dim = self.d_qk * self.num_kv_heads
|
| 350 |
self.v_out_dim = self.d_v * self.num_kv_heads
|
| 351 |
|
| 352 |
+
self.is_causal = is_decoder
|
| 353 |
+
|
| 354 |
self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
|
| 355 |
self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
|
| 356 |
self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
|
|
|
|
| 378 |
self.lambdas = nn.Parameter(torch.tensor([0.5]))
|
| 379 |
|
| 380 |
self.sequence_length = config.max_sequence_length
|
|
|
|
| 381 |
self.window_length = None
|
| 382 |
|
| 383 |
def set_window_length(self, window_length: int):
|
|
|
|
| 527 |
|
| 528 |
|
| 529 |
class Layer(nn.Module):
|
| 530 |
+
def __init__(self, config: GptBertConfig, layer_idx: int, is_decoder: bool):
|
| 531 |
super().__init__()
|
| 532 |
|
| 533 |
+
self.attention = SelfAttention(config, layer_idx, is_decoder)
|
| 534 |
self.mlp = FeedForward(config)
|
| 535 |
self.lambdas = nn.Parameter(torch.tensor([0., 0., 1., 0., 1., 0.]))
|
| 536 |
|
|
|
|
| 551 |
|
| 552 |
|
| 553 |
class Encoder(nn.Module):
|
| 554 |
+
def __init__(self, config: GptBertConfig, is_decoder: bool):
|
| 555 |
super().__init__()
|
| 556 |
+
self.layers = nn.ModuleList([Layer(config, i, is_decoder) for i in range(config.num_layers)])
|
| 557 |
self.local_global_ratio = config.local_global_ratio
|
| 558 |
|
| 559 |
def set_window_length(self, config: GptBertConfig):
|
|
|
|
| 614 |
super().__init__(config, **kwargs)
|
| 615 |
self.config = config
|
| 616 |
self.hidden_size = config.hidden_size
|
| 617 |
+
self.is_decoder = self.is_decoder if hasattr(self, "is_decoder") else False
|
| 618 |
|
| 619 |
self.embedding = Embedding(config)
|
| 620 |
+
self.encoder = Encoder(config, self.is_decoder)
|
| 621 |
self.classifier = LMClassifier(config, config.vocab_size) if add_mlm_layer else None
|
| 622 |
self.set_window_length(config)
|
| 623 |
self.gradient_checkpointing = False
|
|
|
|
| 720 |
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 721 |
|
| 722 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 723 |
+
self.is_decoder = False
|
| 724 |
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 725 |
|
| 726 |
def get_output_embeddings(self):
|
|
|
|
| 772 |
_tied_weights_keys = ["classifier.emb2vocab.weight"]
|
| 773 |
|
| 774 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 775 |
+
self.is_decoder = True
|
| 776 |
super().__init__(config, add_mlm_layer=True, **kwargs)
|
| 777 |
|
| 778 |
def get_output_embeddings(self):
|
|
|
|
| 889 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 890 |
|
| 891 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 892 |
+
self.is_decoder = False
|
| 893 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 894 |
|
| 895 |
self.num_labels = config.num_labels
|
|
|
|
| 945 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 946 |
|
| 947 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 948 |
+
self.is_decoder = False
|
| 949 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 950 |
|
| 951 |
self.num_labels = config.num_labels
|
|
|
|
| 983 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 984 |
|
| 985 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 986 |
+
self.is_decoder = False
|
| 987 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 988 |
|
| 989 |
self.num_labels = config.num_labels
|
|
|
|
| 1040 |
_keys_to_ignore_on_load_unexpected = ["classifier.emb2vocab.weight", "classifier.emb2vocab.bias"]
|
| 1041 |
|
| 1042 |
def __init__(self, config: GptBertConfig, **kwargs):
|
| 1043 |
+
self.is_decoder = False
|
| 1044 |
super().__init__(config, add_mlm_layer=False, **kwargs)
|
| 1045 |
|
| 1046 |
self.num_labels = getattr(config, "num_labels", 2)
|