| """FlashPPI model configuration.""" |
|
|
from typing import Optional

from transformers import PretrainedConfig
|
|
|
|
class FlashPPIConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of a FlashPPI model.

    Every constructor argument is stored unchanged on the instance under the
    same name; nothing is validated or derived here. Extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): the per-argument descriptions below are inferred from the
    argument names only -- the code that consumes these values is not part
    of this file. Confirm them against the model implementation.

    Args:
        plm_dim: Presumably the hidden size of the language-model ("plm")
            backbone.
        plm_depth: Presumably the number of backbone layers.
        plm_heads: Presumably the number of backbone attention heads.
        plm_vocab_size: Presumably the backbone token vocabulary size.
        plm_norm_eps: Presumably the epsilon used by the backbone's
            normalization layers.
        plm_swiglu_multiple_of: Presumably the value the SwiGLU feed-forward
            width is rounded up to a multiple of.
        plm_ffn_dim_multiplier: Optional feed-forward width multiplier;
            ``None`` (the default) is stored as-is. How ``None`` is
            interpreted is decided by the consumer -- TODO confirm.
        clip_embed_dim: Presumably the dimension of the CLIP-style embedding.
        contact_embed_dim: Presumably the embedding size of the contact
            module.
        contact_num_heads: Presumably the number of attention heads in the
            contact module.
        contact_transformer_depth: Presumably the number of transformer
            layers in the contact module.
        max_position_embeddings: Presumably the maximum supported sequence
            length.
        use_flash_attention: Flag stored for the model; presumably toggles a
            flash-attention code path.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    model_type = "flashppi"

    def __init__(
        self,
        # Language-model backbone settings.
        plm_dim: int = 1280,
        plm_depth: int = 33,
        plm_heads: int = 20,
        plm_vocab_size: int = 37,
        plm_norm_eps: float = 1e-5,
        plm_swiglu_multiple_of: int = 256,
        plm_ffn_dim_multiplier: Optional[float] = None,
        # CLIP / contact module settings.
        clip_embed_dim: int = 1024,
        contact_embed_dim: int = 1280,
        contact_num_heads: int = 8,
        contact_transformer_depth: int = 2,
        max_position_embeddings: int = 512,
        use_flash_attention: bool = True,
        **kwargs
    ) -> None:
        # The base-class initializer runs first, exactly as before, so the
        # assignments below are applied after whatever it sets up.
        super().__init__(**kwargs)

        self.plm_dim = plm_dim
        self.plm_depth = plm_depth
        self.plm_heads = plm_heads
        self.plm_vocab_size = plm_vocab_size
        self.plm_norm_eps = plm_norm_eps
        self.plm_swiglu_multiple_of = plm_swiglu_multiple_of
        self.plm_ffn_dim_multiplier = plm_ffn_dim_multiplier

        self.clip_embed_dim = clip_embed_dim
        self.contact_embed_dim = contact_embed_dim
        self.contact_num_heads = contact_num_heads
        self.contact_transformer_depth = contact_transformer_depth
        self.max_position_embeddings = max_position_embeddings
        self.use_flash_attention = use_flash_attention
|
|