File size: 331 Bytes
fe668e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers import PretrainedConfig


class GPJTGPT2Config(PretrainedConfig):
    """Transformers configuration registered under ``model_type = "gpjtgpt2"``.

    Args:
        cfg: Optional raw settings mapping, kept as-is on ``self.cfg``. When
            provided, its ``"n_layers"`` entry is copied onto
            ``num_hidden_layers`` (presumably so generic Transformers code
            that reads that attribute sees the layer count — confirm).
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: if ``cfg`` is a dict without an ``"n_layers"`` key.

    Note:
        ``use_cache`` is always ``False`` once construction finishes, even
        when a ``use_cache`` keyword argument was supplied.
    """

    model_type = "gpjtgpt2"

    def __init__(self, cfg=None, **kwargs):
        # Attach the raw settings before the base initializer runs.
        self.cfg = cfg

        if cfg is not None:
            layer_count = cfg["n_layers"]
            # Assigned before super().__init__ so that whatever the base
            # initializer does with **kwargs happens afterwards — NOTE(review):
            # if a num_hidden_layers kwarg is passed, the base class presumably
            # wins; confirm against PretrainedConfig.
            self.num_hidden_layers = layer_count

        super().__init__(**kwargs)

        # Set last, so it overrides any use_cache value the base class applied.
        self.use_cache = False