File size: 331 Bytes
fe668e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | from transformers import PretrainedConfig
class GPJTGPT2Config(PretrainedConfig):
    """Hugging Face config wrapper around a raw model-settings dict.

    The settings dict is kept verbatim on ``self.cfg``. Its ``"n_layers"``
    entry is also mirrored into the standard ``num_hidden_layers``
    attribute. Every other keyword argument is forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        cfg: Mapping of model settings, or ``None``. When not ``None`` it
            must contain an ``"n_layers"`` key, otherwise ``KeyError`` is
            raised. When ``None``, ``num_hidden_layers`` is not assigned by
            this initializer. NOTE(review): ``None`` is presumably accepted
            so transformers internals can construct the class with no
            arguments; confirm.
        **kwargs: Passed straight through to ``PretrainedConfig.__init__``.

    Note:
        ``use_cache`` is unconditionally set to ``False`` after the base
        initializer has run, whatever value it held at that point.
    """

    model_type = "gpjtgpt2"

    def __init__(self, cfg=None, **kwargs):
        # Own attributes are assigned before the base initializer runs,
        # preserving the original ordering.
        self.cfg = cfg
        has_settings = cfg is not None
        if has_settings:
            # Presumably exposed under this standard name so generic
            # transformers code that reads num_hidden_layers sees it.
            self.num_hidden_layers = cfg["n_layers"]

        super().__init__(**kwargs)

        # Assigned last so nothing applied by the base initializer can
        # re-enable it. NOTE(review): presumably the underlying model does
        # not support KV caching; confirm.
        self.use_cache = False
|