# mRNA-FM / configuration_rnafm.py
# Commit 7ba311b (verified) by Taykhoom: "Upload folder using huggingface_hub"
from transformers import PretrainedConfig
class RnaFmConfig(PretrainedConfig):
    """Configuration for RNA-FM models (``model_type = "rnafm"``).

    Stores the architecture hyper-parameters and special-token indices used by
    the RNA-FM modeling and tokenization code. Any extra keyword arguments are
    forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Vocabulary size (number of token ids).
        num_layers: Number of transformer layers.
        embed_dim: Embedding / hidden width.
        ffn_embed_dim: Inner width of the feed-forward sub-layers.
        attention_heads: Number of attention heads. Must be positive and
            divide ``embed_dim`` evenly.
        padding_idx: Index of the padding token. Passed to
            ``PretrainedConfig.__init__`` as a keyword argument rather than
            assigned directly in this method.
        mask_idx: Index of the mask token.
        cls_idx: Index of the CLS token.
        eos_idx: Index of the end-of-sequence token.
        token_dropout: Flag stored on the config for the modeling code.
        emb_layer_norm_before: Flag stored on the config for the modeling code.
        model_max_length: Maximum sequence length stored on the config.
        model_variant: Variant name (default ``"rna"``) stored on the config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``attention_heads`` is not positive, or if ``embed_dim``
            is not divisible by ``attention_heads``.
    """

    model_type = "rnafm"
    # Maps Auto* entry points to this repo's custom classes (trust_remote_code).
    # NOTE(review): this is a class-level dict shared by every instance and is
    # not assigned in ``__init__``; confirm whether it is expected to end up
    # in a saved config.json.
    auto_map = {
        "AutoConfig": "configuration_rnafm.RnaFmConfig",
        "AutoModel": "modeling_rnafm.RnaFmModel",
        "AutoModelForMaskedLM": "modeling_rnafm.RnaFmForMaskedLM",
        # [slow tokenizer, fast tokenizer]; no fast tokenizer is provided.
        "AutoTokenizer": ["tokenization_rnafm.RnaFmTokenizer", None],
    }

    def __init__(
        self,
        vocab_size: int = 25,
        num_layers: int = 12,
        embed_dim: int = 640,
        ffn_embed_dim: int = 5120,
        attention_heads: int = 20,
        padding_idx: int = 1,
        mask_idx: int = 24,
        cls_idx: int = 0,
        eos_idx: int = 2,
        token_dropout: bool = False,
        emb_layer_norm_before: bool = True,
        model_max_length: int = 1024,
        model_variant: str = "rna",
        **kwargs,
    ):
        # Validate up front so an unusable head layout fails here, with a clear
        # message, instead of deep inside model construction. The zero check
        # also guards the modulo below against ZeroDivisionError.
        if attention_heads <= 0:
            raise ValueError(
                f"attention_heads must be a positive integer, got {attention_heads}"
            )
        if embed_dim % attention_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"attention_heads ({attention_heads})"
            )

        super().__init__(padding_idx=padding_idx, **kwargs)
        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.mask_idx = mask_idx
        self.cls_idx = cls_idx
        self.eos_idx = eos_idx
        self.token_dropout = token_dropout
        self.emb_layer_norm_before = emb_layer_norm_before
        self.model_max_length = model_max_length
        self.model_variant = model_variant