Update modeling_ostlm.py
modeling_ostlm.py CHANGED (+11 -14)
@@ -51,7 +51,7 @@ class OSTLMConfig(PretrainedConfig):
 class OSTLMModel(PreTrainedModel, GenerationMixin):
     config_class = OSTLMConfig
 
-    #
+    # Fixes the AttributeError
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -69,15 +69,14 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
             batch_first=True
         )
 
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        # Changed to bias=True to match your existing weights
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=True)
 
-        # Weight tying
+        # Weight tying
         self.lm_head.weight = self.embedding.weight
 
-        # Weight initialization and synchronization
         self.post_init()
 
-        # Required for weight tying with AutoModel
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -90,7 +89,9 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
             super().__init__()
             self.outer = outer
         def forward(self, input_ids=None, **kwargs):
-
+            if input_ids is None:
+                # Guard for the case where generate passes other inputs
+                input_ids = kwargs.get("decoder_input_ids")
             seq_len = input_ids.size(1)
             out = self.outer.embedding(input_ids) + self.outer.pos_emb[:, :seq_len, :]
             return BaseModelOutput(last_hidden_state=out)
@@ -104,7 +105,7 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         labels=None,
         **kwargs
     ):
-        #
+        # Encoder handling
         if encoder_outputs is not None:
             if isinstance(encoder_outputs, (tuple, list)):
                 src_emb = encoder_outputs[0]
@@ -115,11 +116,10 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         else:
             src_emb = self.embedding(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
 
-        #
+        # Decoder handling - make sure there are no duplicate parameters from generate
         if decoder_input_ids is None:
             decoder_input_ids = kwargs.get("input_ids")
             if decoder_input_ids is None:
-                # Default initialization for generation
                 decoder_input_ids = torch.full(
                     (src_emb.size(0), 1),
                     self.config.decoder_start_token_id,
@@ -127,15 +127,11 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
                 ).to(self.device)
 
         tgt_emb = self.embedding(decoder_input_ids) + self.pos_emb[:, :decoder_input_ids.size(1), :]
-
-        # Create the causal mask
         tgt_mask = self.transformer.generate_square_subsequent_mask(decoder_input_ids.size(1)).to(self.device)
 
-        # 3. Run through the transformer
         out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
         logits = self.lm_head(out)
 
-        # 4. Compute the loss if labels are provided
         loss = None
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
@@ -148,11 +144,12 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         )
 
     def prepare_inputs_for_generation(self, input_ids, encoder_outputs=None, **kwargs):
+        # Fixes the unused model_kwargs issue
         return {
             "decoder_input_ids": input_ids,
             "encoder_outputs": encoder_outputs,
         }
 
-# Registration
+# Final registration
 AutoConfig.register("ostlm", OSTLMConfig)
 AutoModelForSeq2SeqLM.register(OSTLMConfig, OSTLMModel)
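
Two of the changes above work together: `_tied_weights_keys` plus `get_output_embeddings()` are what let the Transformers tying machinery (`tie_weights()`, reached via `post_init()`, and `save_pretrained`) treat `lm_head.weight` as a shared tensor rather than a missing checkpoint key. Below is a minimal sketch of the invariant this commit establishes; it assumes `modeling_ostlm.py` is importable, and the `OSTLMConfig` arguments are illustrative guesses, not the real signature:

# Minimal sketch: check the weight tying wired up in __init__.
# Assumes modeling_ostlm.py is importable; the config arguments are
# illustrative guesses, not the real OSTLMConfig signature.
from modeling_ostlm import OSTLMConfig, OSTLMModel

config = OSTLMConfig(vocab_size=32000, d_model=512)
model = OSTLMModel(config)

# The head and the input embedding share one Parameter object,
# and get_output_embeddings() exposes the tied head to AutoModel.
assert model.lm_head.weight is model.embedding.weight
assert model.get_output_embeddings() is model.lm_head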
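
With the `AutoConfig` / `AutoModelForSeq2SeqLM` registration in place, the model loads through the Auto classes like any other seq2seq checkpoint, and `generate()` drives the `prepare_inputs_for_generation` hook above, feeding `encoder_outputs` plus the growing `decoder_input_ids` back into `forward` on every step. A minimal usage sketch; the repo id and tokenizer are hypothetical, and it assumes this file ships with the checkpoint as custom code:

# Usage sketch. "user/ostlm-demo" is a hypothetical repo id; assumes the
# checkpoint bundles modeling_ostlm.py as trust-remote-code custom code.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("user/ostlm-demo", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("user/ostlm-demo", trust_remote_code=True)

inputs = tokenizer("example input", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))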