if001 committed on
Commit
1b9a33d
·
verified ·
1 Parent(s): 9374ea2

Update modeling_residualnet.py

Browse files
Files changed (1) hide show
  1. modeling_residualnet.py +19 -1
modeling_residualnet.py CHANGED
@@ -342,11 +342,29 @@ class ResidualNetForCausalLM(Phi3PreTrainedModel, GenerationMixin):
342
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
343
 
344
  # weight tying
345
- self.lm_head.weight = self.model.embed_tokens.weight
346
 
347
  # Initialize weights and apply final processing
348
  self.post_init()
349
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  def forward(
351
  self,
352
  input_ids: Optional[torch.LongTensor] = None,
 
342
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
343
 
344
  # weight tying
345
+ # self.lm_head.weight = self.model.embed_tokens.weight
346
 
347
  # Initialize weights and apply final processing
348
  self.post_init()
349
 
350
+ def get_input_embeddings(self):
351
+ return self.model.embed_tokens
352
+
353
+ def set_input_embeddings(self, value):
354
+ self.model.embed_tokens = value
355
+
356
+ def get_output_embeddings(self):
357
+ return self.lm_head
358
+
359
+ def set_output_embeddings(self, new_embeddings):
360
+ self.lm_head = new_embeddings
361
+
362
+ def set_decoder(self, decoder):
363
+ self.model = decoder
364
+
365
+ def get_decoder(self):
366
+ return self.model
367
+
368
  def forward(
369
  self,
370
  input_ids: Optional[torch.LongTensor] = None,