| from typing import Tuple, List, Optional | |
| from transformers import PretrainedConfig | |
class PSIConfig(PretrainedConfig):
    """Hugging Face configuration for the PSI model.

    Stores the transformer hyperparameters (vocabulary size, depth, width,
    attention head count, dropout, etc.) and forwards any remaining keyword
    arguments to ``PretrainedConfig``, which assigns them via ``setattr``.
    Checkpoint-specific values such as token ranges are therefore expected
    to arrive through ``**kwargs`` rather than being declared here.
    """

    model_type: str = "PSI"

    def __init__(self,
                 vocab_size: int = 96256,
                 channel_size: int = 12,
                 n_layer: int = 12,
                 n_head: int = 12,
                 n_embd: int = 768,
                 dropout: float = 0.0,
                 bias: bool = False,
                 attention_mask: str = "causal",
                 tie_weights: bool = False,
                 partition_embedding: bool = False,
                 n_lm_vocab: Optional[int] = None,
                 **kwargs
                 ):
        # Record every explicitly named hyperparameter on the instance.
        hparams = dict(
            vocab_size=vocab_size,
            channel_size=channel_size,
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            dropout=dropout,
            bias=bias,
            attention_mask=attention_mask,
            tie_weights=tie_weights,
            partition_embedding=partition_embedding,
            n_lm_vocab=n_lm_vocab,
        )
        for attr, value in hparams.items():
            setattr(self, attr, value)
        # Aside from HuggingFace default config attributes,
        # all extra kwargs are assigned using setattr. For HuggingFace attrs, see:
        # https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/configuration_utils.py#L45
        # Since token ranges are checkpoint-specific, we don't include them
        # in this config and let them be assigned from kwargs.
        super().__init__(**kwargs)