Upload modeling.py
modeling.py  (+17 -1)

@@ -180,6 +180,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
         assert config.sequence_length is not None
         self.config = config
         self.tokenizer = tiktoken.get_encoding("gpt2")
+        self.base_model_prefix = "timoe"

         self.transformer = nn.ModuleDict(
             dict(
@@ -260,7 +261,7 @@ class MoEGPTForCausalLM(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

-    def forward(self, idx, date=None, targets=None, get_logits=True, moe=False):
+    def forward(self, idx, date=None, targets=None, attention_mask=None, get_logits=True, moe=False):
         device = idx.device
         b, t = idx.size()
         assert (
@@ -463,3 +464,18 @@ class MoEGPTForCausalLM(PreTrainedModel):
             .numpy()
         )
         return self.tokenizer.decode(out_idx).split(in_str)[-1]
+
+
+    def get_input_embeddings(self):
+        return self.transformer.wte
+
+    def set_input_embeddings(self, new_embeddings):
+        self.transformer.wte = new_embeddings
+        # reset the lm_head to use the new embeddings
+        # this is necessary because the lm_head is tied to the input embeddings
+        self.lm_head = nn.Linear(
+            self.config.n_embd, new_embeddings.weight.shape[0], bias=False
+        )
+        # self.transformer.wte.weight = (
+        #     self.lm_head.weight
+        # )
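Note on the new embedding accessors: get_input_embeddings / set_input_embeddings are the hooks Hugging Face's PreTrainedModel machinery (e.g. resize_token_embeddings) uses to swap the token embedding table, and because this model ties lm_head to the embeddings, the head must be rebuilt to match the new vocabulary size. The diff leaves the explicit re-tying of transformer.wte.weight to lm_head.weight commented out, so the rebuilt head starts untied. Below is a minimal, self-contained PyTorch sketch of that tie/rebuild behaviour; it is not the repo's code, and the names (wte, lm_head) and the dimensions 8/10/12 are purely illustrative.

import torch
import torch.nn as nn

n_embd, old_vocab, new_vocab = 8, 10, 12

# original tied setup: embedding table and output head share one (vocab, n_embd) matrix
wte = nn.Embedding(old_vocab, n_embd)
lm_head = nn.Linear(n_embd, old_vocab, bias=False)
lm_head.weight = wte.weight  # weight tying, as in GPT-style models

# simulate set_input_embeddings(new_embeddings): swap the embedding table ...
new_wte = nn.Embedding(new_vocab, n_embd)
# ... and rebuild lm_head so its output dimension follows the new vocab size,
# mirroring the change in the diff above; re-tying would require assigning
# new_wte.weight to lm_head.weight afterwards.
lm_head = nn.Linear(n_embd, new_wte.weight.shape[0], bias=False)

idx = torch.randint(0, new_vocab, (1, 4))
logits = lm_head(new_wte(idx))
print(logits.shape)  # torch.Size([1, 4, 12])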