"""FlashPPI model configuration."""
from typing import Optional

from transformers import PretrainedConfig
class FlashPPIConfig(PretrainedConfig):
    """Configuration for a FlashPPI model.

    Bundles two groups of hyperparameters:

    * the gLM2 protein-language-model backbone (``plm_*`` fields, whose
      defaults match the gLM2_650M checkpoint), and
    * the FlashPPI head (CLIP-style projection and contact-prediction
      transformer).

    All extra keyword arguments are forwarded to
    :class:`~transformers.PretrainedConfig`.
    """

    # Registered model identifier used by the transformers auto classes.
    model_type = "flashppi"

    def __init__(
        self,
        # gLM2 backbone config (defaults match gLM2_650M)
        plm_dim: int = 1280,
        plm_depth: int = 33,
        plm_heads: int = 20,
        plm_vocab_size: int = 37,
        plm_norm_eps: float = 1e-5,
        plm_swiglu_multiple_of: int = 256,
        plm_ffn_dim_multiplier: Optional[float] = None,
        # FlashPPI head config
        clip_embed_dim: int = 1024,
        contact_embed_dim: int = 1280,
        contact_num_heads: int = 8,
        contact_transformer_depth: int = 2,
        max_position_embeddings: int = 512,
        use_flash_attention: bool = True,
        **kwargs,
    ):
        """Initialize the configuration.

        Args:
            plm_dim: Hidden size of the gLM2 backbone.
            plm_depth: Number of transformer layers in the backbone.
            plm_heads: Number of attention heads in the backbone.
            plm_vocab_size: Token vocabulary size of the backbone.
            plm_norm_eps: Epsilon used by the backbone's normalization layers.
            plm_swiglu_multiple_of: Round the SwiGLU hidden dim up to a
                multiple of this value.
            plm_ffn_dim_multiplier: Optional multiplier for the backbone FFN
                dimension; ``None`` keeps the backbone's default sizing.
            clip_embed_dim: Dimension of the CLIP-style projection embedding.
            contact_embed_dim: Hidden size of the contact-prediction head.
            contact_num_heads: Attention heads in the contact head.
            contact_transformer_depth: Transformer layers in the contact head.
            max_position_embeddings: Maximum sequence length supported by the
                position embeddings.
            use_flash_attention: Whether to enable flash attention kernels.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        # gLM2 config
        self.plm_dim = plm_dim
        self.plm_depth = plm_depth
        self.plm_heads = plm_heads
        self.plm_vocab_size = plm_vocab_size
        self.plm_norm_eps = plm_norm_eps
        self.plm_swiglu_multiple_of = plm_swiglu_multiple_of
        self.plm_ffn_dim_multiplier = plm_ffn_dim_multiplier
        # FlashPPI config
        self.clip_embed_dim = clip_embed_dim
        self.contact_embed_dim = contact_embed_dim
        self.contact_num_heads = contact_num_heads
        self.contact_transformer_depth = contact_transformer_depth
        self.max_position_embeddings = max_position_embeddings
        self.use_flash_attention = use_flash_attention