# PSI / config.py
# TheTrueJard's picture
# Upload folder using huggingface_hub
# 6f09125 verified
from typing import Tuple, List, Optional
from transformers import PretrainedConfig
class PSIConfig(PretrainedConfig):
    """Configuration for the PSI model.

    Holds the architecture hyperparameters (vocabulary/channel sizes,
    layer/head/embedding dimensions) together with a few behavioral
    switches (dropout, bias, attention masking, weight tying, embedding
    partitioning). Any extra keyword arguments are forwarded to
    ``PretrainedConfig``, which assigns unknown keys onto the instance
    via ``setattr`` — that is how checkpoint-specific values such as
    token ranges end up on the config without being declared here.
    """

    model_type: str = "PSI"

    def __init__(self,
                 vocab_size: int = 96256,
                 channel_size: int = 12,
                 n_layer: int = 12,
                 n_head: int = 12,
                 n_embd: int = 768,
                 dropout: float = 0.0,
                 bias: bool = False,
                 attention_mask: str = "causal",
                 tie_weights: bool = False,
                 partition_embedding: bool = False,
                 n_lm_vocab: Optional[int] = None,
                 **kwargs
                 ):
        # Record every explicitly declared hyperparameter on the
        # instance before handing off to the base class.
        explicit_fields = (
            ("vocab_size", vocab_size),
            ("channel_size", channel_size),
            ("n_layer", n_layer),
            ("n_head", n_head),
            ("n_embd", n_embd),
            ("dropout", dropout),
            ("bias", bias),
            ("attention_mask", attention_mask),
            ("tie_weights", tie_weights),
            ("partition_embedding", partition_embedding),
            ("n_lm_vocab", n_lm_vocab),
        )
        for field_name, field_value in explicit_fields:
            setattr(self, field_name, field_value)

        # Aside from HuggingFace default config attributes,
        # all extra kwargs are assigned using setattr. For HuggingFace attrs, see:
        # https://github.com/huggingface/transformers/blob/v4.53.3/src/transformers/configuration_utils.py#L45
        # Since token ranges are checkpoint-specific, we don't include them
        # in this config and let them be assigned from kwargs.
        super().__init__(**kwargs)