JonusNattapong committed on
Commit
e5dbf14
·
verified ·
1 Parent(s): ca2d027

End of training

Browse files
Files changed (1) hide show
  1. modeling_gptoss_mini.py +4 -0
modeling_gptoss_mini.py CHANGED
@@ -98,6 +98,9 @@ class Block(nn.Module):
98
  class GPTMiniForCausalLM(PreTrainedModel, GenerationMixin):
99
  config_class = GPTMiniConfig
100
 
 
 
 
101
  def __init__(self, config: GPTMiniConfig):
102
  super().__init__(config)
103
  self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -106,6 +109,7 @@ class GPTMiniForCausalLM(PreTrainedModel, GenerationMixin):
106
  self.ln_f = RMSNorm(config.hidden_size)
107
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
108
 
 
109
  self.post_init()
110
 
111
  def get_input_embeddings(self):
 
98
  class GPTMiniForCausalLM(PreTrainedModel, GenerationMixin):
99
  config_class = GPTMiniConfig
100
 
101
+ _keys_to_ignore_on_save = []
102
+ _dynamic_tied_weights_keys = {"lm_head.weight", "embed.weight"}
103
+
104
  def __init__(self, config: GPTMiniConfig):
105
  super().__init__(config)
106
  self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
 
109
  self.ln_f = RMSNorm(config.hidden_size)
110
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
111
 
112
+ self.lm_head.weight = self.embed.weight
113
  self.post_init()
114
 
115
  def get_input_embeddings(self):