from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class RITAConfig(PretrainedConfig):
    """Configuration class to store the configuration of a RITA model.

    The defaults describe a 12-layer model with a hidden width of 768,
    12 attention heads, and a 26-token vocabulary.
    """

    model_type = "rita"

    def __init__(
        self,
        vocab_size=26,
        d_model=768,
        num_layers=12,
        max_seq_len=1024,
        num_heads=12,
        dropout=0.0,
        ff_ratio=4,
        eos_token_id=2,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_heads = num_heads
        # The feed-forward width is derived from the model width.
        self.d_feedforward = d_model * ff_ratio
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.eos_token_id = eos_token_id
        self.initializer_range = initializer_range
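# Minimal usage sketch: build a config with illustrative (non-default)
# hyperparameter values, which are assumptions, and inspect the derived
# feed-forward width.
if __name__ == "__main__":
    config = RITAConfig(d_model=1024, num_layers=24, num_heads=16)
    print(config.d_feedforward)  # 4096, i.e. d_model * ff_ratio with the default ff_ratio=4
    print(config.model_type)     # "rita"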