dev-das commited on
Commit
589a4c0
·
verified ·
1 Parent(s): ee7b825

Update modeling_my_gpt.py

Browse files
Files changed (1) hide show
  1. modeling_my_gpt.py +5 -6
modeling_my_gpt.py CHANGED
@@ -2,25 +2,24 @@ import torch
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
  from transformers.modeling_outputs import CausalLMOutput
 
 
5
 
6
  from .configuration_my_gpt import MyGPTConfig
7
  from .untrained_model import GPTModel
8
 
9
  import os
10
- import sys
11
-
12
- curr_dir = os.getcwd()
13
- parent_dir = os.path.dirname(curr_dir)
14
 
15
- sys.path.insert(0, parent_dir)
16
 
17
 
18
- class MyGPTForCausalLM(PreTrainedModel):
19
  config_class = MyGPTConfig
 
20
 
21
  def __init__(self, config):
22
  super().__init__(config)
23
 
 
24
  # Import your original GPTModel
25
  self.model = GPTModel({
26
  "vocab_size": config.vocab_size,
 
2
  import torch.nn as nn
3
  from transformers import PreTrainedModel
4
  from transformers.modeling_outputs import CausalLMOutput
5
+ from transformers.generation import GenerationMixin
6
+
7
 
8
  from .configuration_my_gpt import MyGPTConfig
9
  from .untrained_model import GPTModel
10
 
11
  import os
 
 
 
 
12
 
 
13
 
14
 
15
+ class MyGPTForCausalLM(PreTrainedModel, GenerationMixin):
16
  config_class = MyGPTConfig
17
+ main_input_name = "input_ids"
18
 
19
  def __init__(self, config):
20
  super().__init__(config)
21
 
22
+
23
  # Import your original GPTModel
24
  self.model = GPTModel({
25
  "vocab_size": config.vocab_size,