Raziel1234 committed
Commit ef1ba69 (verified)
Parent: c3ff8f6

Update modeling_ostlm.py

Files changed (1): modeling_ostlm.py (+11 -14)
modeling_ostlm.py CHANGED

@@ -51,7 +51,7 @@ class OSTLMConfig(PretrainedConfig):
 class OSTLMModel(PreTrainedModel, GenerationMixin):
     config_class = OSTLMConfig
 
-    # Bug fix: define the tied-weight keys for the Transformers library
+    # Fixes the AttributeError
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -69,15 +69,14 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
             batch_first=True
         )
 
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+        # Changed to bias=True to match your existing weights
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=True)
 
-        # Tying the weights (Weight Tying)
+        # Weight tying
         self.lm_head.weight = self.embedding.weight
 
-        # Weight initialization and synchronization
         self.post_init()
 
-    # Required for weight tying in AutoModel
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -90,7 +89,9 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
             super().__init__()
             self.outer = outer
         def forward(self, input_ids=None, **kwargs):
-            # Support for padding shorter than max_pos
+            if input_ids is None:
+                # Guard for the case where generate passes other inputs
+                input_ids = kwargs.get("decoder_input_ids")
             seq_len = input_ids.size(1)
             out = self.outer.embedding(input_ids) + self.outer.pos_emb[:, :seq_len, :]
             return BaseModelOutput(last_hidden_state=out)
@@ -104,7 +105,7 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         labels=None,
         **kwargs
     ):
-        # 1. Handle the encoder
+        # Handle the encoder
         if encoder_outputs is not None:
             if isinstance(encoder_outputs, (tuple, list)):
                 src_emb = encoder_outputs[0]
@@ -115,11 +116,10 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         else:
             src_emb = self.embedding(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
 
-        # 2. Handle the decoder input
+        # Handle the decoder - make sure generate passes no duplicate parameters
         if decoder_input_ids is None:
             decoder_input_ids = kwargs.get("input_ids")
         if decoder_input_ids is None:
-            # Default initialization for generation
            decoder_input_ids = torch.full(
                 (src_emb.size(0), 1),
                 self.config.decoder_start_token_id,
@@ -127,15 +127,11 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
             ).to(self.device)
 
         tgt_emb = self.embedding(decoder_input_ids) + self.pos_emb[:, :decoder_input_ids.size(1), :]
-
-        # Build the causal mask
         tgt_mask = self.transformer.generate_square_subsequent_mask(decoder_input_ids.size(1)).to(self.device)
 
-        # 3. Run through the transformer
         out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
         logits = self.lm_head(out)
 
-        # 4. Compute the loss if labels are provided
         loss = None
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
@@ -148,11 +144,12 @@ class OSTLMModel(PreTrainedModel, GenerationMixin):
         )
 
     def prepare_inputs_for_generation(self, input_ids, encoder_outputs=None, **kwargs):
+        # Fixes the unused model_kwargs issue
         return {
             "decoder_input_ids": input_ids,
             "encoder_outputs": encoder_outputs,
         }
 
-# Registration for use with AutoModel
+# Registration
 AutoConfig.register("ostlm", OSTLMConfig)
 AutoModelForSeq2SeqLM.register(OSTLMConfig, OSTLMModel)
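For context on the `bias=False` → `bias=True` change: weight tying shares only the projection's weight matrix with the embedding table, so `lm_head.bias` stays an independent parameter and appears in the state dict only when `bias=True`. A minimal sketch of that behavior, using illustrative sizes rather than the model's real config:

```python
# Minimal sketch (not the repository's code): weight tying shares the weight
# matrix only; the bias of lm_head remains a separate, untied parameter.
import torch
import torch.nn as nn

vocab_size, d_model = 100, 32  # illustrative sizes, not OSTLMConfig values

embedding = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=True)

# Tie the output projection to the embedding table, as __init__ does.
lm_head.weight = embedding.weight

assert lm_head.weight is embedding.weight     # shared storage
assert lm_head.bias.shape == (vocab_size,)    # bias exists and is untied

# Linear applies the tied matrix transposed: logits = hidden @ E.T + b
hidden = torch.randn(2, 5, d_model)
assert lm_head(hidden).shape == (2, 5, vocab_size)
```

A checkpoint saved with an `lm_head.bias` entry therefore only loads cleanly into a module built with `bias=True`, which is the mismatch the diff comment refers to.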
 
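With the config and model registered against the auto classes, the model should be loadable through `AutoModelForSeq2SeqLM`. A hypothetical usage sketch; the repo id and tokenizer below are placeholders, not confirmed paths:

```python
# Hypothetical usage sketch: the repo id is a placeholder, and loading custom
# model code from the Hub requires trust_remote_code=True.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "Raziel1234/ostlm"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    # generate() builds encoder_outputs once via the encoder wrapper, then
    # feeds the growing decoder_input_ids back through
    # prepare_inputs_for_generation at each decoding step.
    output_ids = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```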