jinmang2 commited on
Commit
55949a6
·
1 Parent(s): 606396d

Create configuration_dalle.py

Browse files
Files changed (1) hide show
  1. configuration_dalle.py +31 -0
configuration_dalle.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class DallEConfig(PretrainedConfig):
5
+ def __init__(
6
+ self,
7
+ group_count: int = 4,
8
+ n_hid: int = 256,
9
+ n_blk_per_group: int = 2,
10
+ input_channels: int = 3,
11
+ vocab_size: int = 8192,
12
+ device: str = 'cpu',
13
+ requires_grad: bool = False,
14
+ use_mixed_precision: bool = True,
15
+ **kwargs,
16
+ ):
17
+ super().__init__(**kwargs)
18
+
19
+ assert input_channels >= 1
20
+ assert n_hid >= 64
21
+ assert n_blk_per_group >= 1
22
+ assert vocab_size >= 512
23
+
24
+ self.group_count = group_count
25
+ self.n_hid = n_hid
26
+ self.n_blk_per_group = n_blk_per_group
27
+ self.input_channels = input_channels
28
+ self.vocab_size = vocab_size
29
+ self.device = device
30
+ self.requires_grad = requires_grad
31
+ self.use_mixed_precision = use_mixed_precision