prism-v2-wikitext / configuration_prism.py
prism-lab's picture
Initial anonymous commit
1051294 verified
raw
history blame contribute delete
595 Bytes
from transformers import PretrainedConfig
class PRISMConfig(PretrainedConfig):
    """Configuration object for the ``prism-v2`` model type.

    Each constructor argument below is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are handed
    to ``PretrainedConfig.__init__``. No validation is performed on the
    values.

    Args:
        vocab_size: Stored as ``self.vocab_size`` (default 32768).
            Presumably the tokenizer vocabulary size -- confirm against the
            model implementation.
        d_model: Stored as ``self.d_model`` (default 512). Presumably the
            model's hidden width -- TODO confirm.
        seq_len: Stored as ``self.seq_len`` (default 4096). Presumably the
            maximum sequence length the model is built for -- TODO confirm.
        prism_depth: Stored as ``self.prism_depth`` (default 5). Presumably
            the number of PRISM layers -- TODO confirm.
        trans_depth: Stored as ``self.trans_depth`` (default 1). Presumably
            the number of transformer layers -- TODO confirm.
        fft_dim: Stored as ``self.fft_dim`` (default 64). Its meaning is
            not visible in this file; check the model code.
        dropout: Stored as ``self.dropout`` (default 0.1). Presumably a
            dropout probability -- TODO confirm.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier that transformers records in the serialized
    # configuration for this class.
    model_type = "prism-v2"

    def __init__(
        self,
        vocab_size: int = 32768,
        d_model: int = 512,
        seq_len: int = 4096,
        prism_depth: int = 5,
        trans_depth: int = 1,
        fft_dim: int = 64,
        dropout: float = 0.1,
        **kwargs,
    ) -> None:
        # The base class handles the generic options first; the PRISM-specific
        # fields are attached only afterwards, in a fixed order.
        super().__init__(**kwargs)

        prism_fields = {
            "vocab_size": vocab_size,
            "d_model": d_model,
            "seq_len": seq_len,
            "prism_depth": prism_depth,
            "trans_depth": trans_depth,
            "fft_dim": fft_dim,
            "dropout": dropout,
        }
        for field_name, field_value in prism_fields.items():
            setattr(self, field_name, field_value)