File size: 906 Bytes
55949a6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | from transformers import PretrainedConfig
class DallEConfig(PretrainedConfig):
    """Configuration that stores the DallE hyperparameters as attributes.

    Every named constructor argument is saved unchanged on the instance under
    the same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        group_count: Stored as ``self.group_count``; not validated here.
            NOTE(review): presumably the number of layer groups in the
            model -- confirm against the model code.
        n_hid: Stored as ``self.n_hid``. Must be at least 64.
        n_blk_per_group: Stored as ``self.n_blk_per_group``. Must be at
            least 1.
        input_channels: Stored as ``self.input_channels``. Must be at
            least 1.
        vocab_size: Stored as ``self.vocab_size``. Must be at least 512.
        device: Stored as ``self.device``; not validated here.
        requires_grad: Stored as ``self.requires_grad``.
        use_mixed_precision: Stored as ``self.use_mixed_precision``.
        **kwargs: Passed through to the ``PretrainedConfig`` constructor.

    Raises:
        AssertionError: If ``input_channels``, ``n_hid``,
            ``n_blk_per_group`` or ``vocab_size`` is below its minimum.
    """

    def __init__(
        self,
        group_count: int = 4,
        n_hid: int = 256,
        n_blk_per_group: int = 2,
        input_channels: int = 3,
        vocab_size: int = 8192,
        device: str = 'cpu',
        requires_grad: bool = False,
        use_mixed_precision: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Explicit raises replace bare `assert` statements: asserts are
        # removed when Python runs with -O, which would let invalid values
        # through silently. AssertionError is kept as the exception type so
        # callers that already catch it keep working.
        lower_bounds = (
            ("input_channels", input_channels, 1),
            ("n_hid", n_hid, 64),
            ("n_blk_per_group", n_blk_per_group, 1),
            ("vocab_size", vocab_size, 512),
        )
        for name, value, minimum in lower_bounds:
            # `not value >= minimum` mirrors the original assert condition
            # exactly (including for values that do not compare, e.g. NaN).
            if not value >= minimum:
                raise AssertionError(
                    f"{name} must be >= {minimum}, got {value!r}"
                )

        self.group_count = group_count
        self.n_hid = n_hid
        self.n_blk_per_group = n_blk_per_group
        self.input_channels = input_channels
        self.vocab_size = vocab_size
        self.device = device
        self.requires_grad = requires_grad
        self.use_mixed_precision = use_mixed_precision