Jiqing commited on
Commit
c365ca1
·
verified ·
1 Parent(s): 09326a0

Create configuration_protst.py

Browse files
Files changed (1) hide show
  1. configuration_protst.py +51 -0
configuration_protst.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from transformers.utils import logging
3
+ from transformers.models.esm import EsmConfig
4
+ from transformers.models.bert import BertConfig
5
+
6
+ logger = logging.get_logger(__name__)
7
+
8
+
9
+ class ProtSTConfig(PretrainedConfig):
10
+ r"""
11
+ This is the configuration class to store the configuration of a [`ProtSTModel`].
12
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
13
+ documentation from [`PretrainedConfig`] for more information.
14
+ Args:
15
+ protein_config (`dict`, *optional*):
16
+ Dictionary of configuration options used to initialize [`EsmForProteinRepresentation`].
17
+ text_config (`dict`, *optional*):
18
+ Dictionary of configuration options used to initialize [`BertForPubMed`].
19
+ ```"""
20
+
21
+ model_type = "protst"
22
+
23
+ def __init__(
24
+ self,
25
+ protein_config=None,
26
+ text_config=None,
27
+ **kwargs,
28
+ ):
29
+ super().__init__(**kwargs)
30
+
31
+ if protein_config is None:
32
+ protein_config = {}
33
+ logger.info("`protein_config` is `None`. Initializing the `ProtSTTextConfig` with default values.")
34
+
35
+ if text_config is None:
36
+ text_config = {}
37
+ logger.info("`text_config` is `None`. Initializing the `ProtSTVisionConfig` with default values.")
38
+
39
+ self.protein_config = EsmConfig(**protein_config)
40
+ self.text_config = BertConfig(**text_config)
41
+
42
+ @classmethod
43
+ def from_protein_text_configs(
44
+ cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
45
+ ):
46
+ r"""
47
+ Instantiate a [`ProtSTConfig`] (or a derived class) from ProtST text model configuration. Returns:
48
+ [`ProtSTConfig`]: An instance of a configuration object
49
+ """
50
+
51
+ return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)