File size: 906 Bytes
55949a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from transformers import PretrainedConfig


class DallEConfig(PretrainedConfig):
    """Configuration object for a DALL-E style model.

    All model-specific hyper-parameters are stored as plain attributes.
    Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Args:
        group_count: Number of groups. Stored as-is and not validated here.
        n_hid: Hidden width; must be >= 64.
        n_blk_per_group: Blocks per group; must be >= 1.
        input_channels: Number of input channels; must be >= 1.
        vocab_size: Size of the token vocabulary; must be >= 512.
        device: Device string, stored as-is (presumably a torch device name
            such as ``'cpu'`` -- confirm against the consuming model).
        requires_grad: Stored as-is (presumably whether model parameters
            should require gradients -- confirm against the consuming model).
        use_mixed_precision: Stored as-is (presumably toggles mixed-precision
            execution in the consuming model -- confirm).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If any of ``input_channels``, ``n_hid``,
            ``n_blk_per_group`` or ``vocab_size`` is below its minimum.
    """

    def __init__(
        self,
        group_count: int = 4,
        n_hid: int = 256,
        n_blk_per_group: int = 2,
        input_channels: int = 3,
        vocab_size: int = 8192,
        device: str = 'cpu',
        requires_grad: bool = False,
        use_mixed_precision: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Explicit checks instead of ``assert``: assertions are removed when
        # Python runs with ``-O``, which would silently accept invalid
        # configurations. Same thresholds and order as before.
        for field_name, value, minimum in (
            ('input_channels', input_channels, 1),
            ('n_hid', n_hid, 64),
            ('n_blk_per_group', n_blk_per_group, 1),
            ('vocab_size', vocab_size, 512),
        ):
            if value < minimum:
                raise ValueError(
                    f'{field_name} must be >= {minimum}, got {value!r}'
                )

        self.group_count = group_count
        self.n_hid = n_hid
        self.n_blk_per_group = n_blk_per_group
        self.input_channels = input_channels
        self.vocab_size = vocab_size
        self.device = device
        self.requires_grad = requires_grad
        self.use_mixed_precision = use_mixed_precision