nz commited on
Commit
d5e97d5
·
1 Parent(s): 4b4e0fa

Create rita_configuration.py

Browse files
Files changed (1) hide show
  1. rita_configuration.py +31 -0
rita_configuration.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+ class RITAConfig(PretrainedConfig):
7
+ model_type = "rita"
8
+
9
+ def __init__(
10
+ self,
11
+ in_vocab_size=128,
12
+ out_vocab_size=32,
13
+ d_model=768,
14
+ num_layers=12,
15
+ max_seq_len=1024,
16
+ num_heads=12,
17
+ dropout=0.,
18
+ ff_ratio=4,
19
+ eos_token_id=2,
20
+ **kwargs,
21
+ ):
22
+ super().__init__(eos_token_id=eos_token_id, **kwargs)
23
+ self.in_vocab_size = in_vocab_size
24
+ self.out_vocab_size = out_vocab_size
25
+ self.d_model = d_model
26
+ self.num_heads = num_heads
27
+ self.d_feedforward = d_model*ff_ratio
28
+ self.num_layers = num_layers
29
+ self.max_seq_len=max_seq_len
30
+ self.dropout = dropout
31
+ self.eos_token_id=eos_token_id