from transformers import PretrainedConfig
class PRISMConfig(PretrainedConfig):
    """Hugging Face configuration object for the ``prism-v2`` model.

    Each model hyper-parameter is stored as a plain instance attribute; any
    other keyword arguments are handed to ``PretrainedConfig`` untouched.

    Args:
        vocab_size: Vocabulary size (presumably the number of token ids the
            embedding table covers -- confirm against the model code).
        d_model: Model width (presumably the hidden/embedding dimension).
        seq_len: Sequence length (presumably the maximum context length the
            model is built for -- confirm against the model code).
        prism_depth: Depth setting for the PRISM part of the model
            (presumably a layer count -- confirm against the model code).
        trans_depth: Depth setting for the transformer part of the model
            (presumably a layer count -- confirm against the model code).
        fft_dim: FFT-related dimension; its exact role is defined by the
            model code, not by this config.
        dropout: Dropout rate (presumably a probability in [0, 1)).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier that Hugging Face associates with this config
    # class (e.g. for AutoConfig registration).
    model_type = "prism-v2"

    def __init__(
        self,
        vocab_size: int = 32768,
        d_model: int = 512,
        seq_len: int = 4096,
        prism_depth: int = 5,
        trans_depth: int = 1,
        fft_dim: int = 64,
        dropout: float = 0.1,
        **kwargs,
    ) -> None:
        # Let the base class consume its generic options first; the
        # model-specific attributes below are then set on the instance.
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.seq_len = seq_len
        self.prism_depth = prism_depth
        self.trans_depth = trans_depth
        self.fft_dim = fft_dim
        self.dropout = dropout