robinfaro committed
Commit 6a3cafd · verified · 1 Parent(s): 0f280d8

Upload modeling.py

Files changed (1)
  1. modeling.py +17 -1
modeling.py CHANGED
@@ -180,6 +180,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
         assert config.sequence_length is not None
         self.config = config
         self.tokenizer = tiktoken.get_encoding("gpt2")
+        self.base_model_prefix = "timoe"
 
         self.transformer = nn.ModuleDict(
             dict(
@@ -260,7 +261,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, idx, date=None, targets=None, get_logits=True, moe=False):
+    def forward(self, idx, date=None, targets=None, attention_mask=None, get_logits=True, moe=False):
         device = idx.device
         b, t = idx.size()
         assert (
@@ -463,3 +464,18 @@ class MoEGPTForCausalLM(PreTrainedModel):
             .numpy()
         )
         return self.tokenizer.decode(out_idx).split(in_str)[-1]
+
+
+    def get_input_embeddings(self):
+        return self.transformer.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.transformer.wte = new_embeddings
+        # reset the lm_head to use the new embeddings
+        # this is necessary because the lm_head is tied to the input embeddings
+        self.lm_head = nn.Linear(
+            self.config.n_embd, new_embeddings.weight.shape[0], bias=False
+        )
+        # self.transformer.wte.weight = (
+        #     self.lm_head.weight
+        # )
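
For context, these changes align the model with the standard transformers.PreTrainedModel interface: the new attention_mask argument matches the signature the library's generation utilities pass to forward(), and get_input_embeddings()/set_input_embeddings() are the hooks that resize_token_embeddings() calls internally. A minimal sketch of how the new hooks would be exercised (the checkpoint id below is a hypothetical placeholder, not something confirmed by this diff; resize_token_embeddings is the stock Hugging Face method):

# Sketch only; "robinfaro/MoEGPT" is a placeholder repo id.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "robinfaro/MoEGPT",     # hypothetical checkpoint
    trust_remote_code=True,  # modeling.py is custom code, so remote code is needed
)

old_vocab = model.get_input_embeddings().weight.shape[0]

# resize_token_embeddings() (inherited from PreTrainedModel) works through
# get_input_embeddings()/set_input_embeddings(); with this commit's
# set_input_embeddings(), lm_head is rebuilt to match the new vocab size.
model.resize_token_embeddings(old_vocab + 16)  # e.g. after adding 16 special tokens

assert model.get_input_embeddings().weight.shape[0] == old_vocab + 16

Note that because the weight-tying lines at the end of set_input_embeddings() are commented out in this commit, the rebuilt lm_head starts from fresh nn.Linear initialization rather than sharing the new embedding weights.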