# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Configuration class for TXModel compatible with HuggingFace Transformers
"""

from typing import Dict, Optional

from transformers import PretrainedConfig


class TXConfig(PretrainedConfig):
    """
    Configuration class for TXModel.
    
    This class stores the configuration of a TXModel, which is a Transformer-based model
    for genomic/biological sequence analysis.
    
    Args:
        vocab_size (int): Size of the vocabulary
        d_model (int): Dimensionality of the model embeddings
        n_layers (int): Number of transformer layers
        n_heads (int): Number of attention heads
        expansion_ratio (int): Expansion ratio for FFN
        norm_scheme (str): Normalization scheme ('pre' or 'post')
        transformer_activation (str): Activation function for transformer
        cell_emb_style (str): Cell embedding style ('cls', 'avg-pool', 'w-pool')
        pad_token_id (int): ID of the padding token
        pad_value (float): Value for padding
        num_bins (int): Number of bins for expression values
        use_chem_token (bool): Whether to use chemical token encoder
        attn_config (Dict): Attention configuration
        norm_config (Dict): Normalization configuration
        init_config (Dict): Initialization configuration
        gene_encoder_config (Dict): Gene encoder configuration
        expression_encoder_config (Dict): Expression encoder configuration
        expression_decoder_config (Dict): Expression decoder configuration
        mvc_config (Optional[Dict]): MVC decoder configuration
        chemical_encoder_config (Optional[Dict]): Chemical encoder configuration
        use_glu (bool): Whether to use GLU in FFN
        return_gene_embeddings (bool): Whether to return gene embeddings
        standard_scale_outputs (bool): Whether to standard-scale the model outputs
        keep_first_n_tokens (int): Number of leading tokens in the sequence to keep
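
    Example:
        >>> # Minimal usage: override a few fields; unset fields keep the
        >>> # defaults listed above.
        >>> config = TXConfig(d_model=256, n_layers=6)
        >>> config.n_heads
        8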
    """
    
    model_type = "tx_model"
    
    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        init_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu
        
        # Sub-configurations
        self.attn_config = attn_config or {
            "attn_type": "grouped_query_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "flash",
            "use_attn_mask": False,
            "qk_ln": False,
            "qk_gn": False,
            "clip_qkv": None,
            "softmax_scale": None,
        }
        
        self.norm_config = norm_config or {
            "norm_type": "low_precision_layernorm",
            "eps": 1e-5,
        }
        
        self.init_config = init_config or {
            "name": "kaiming_normal_",
            "fan_mode": "fan_in",
            "init_nonlinearity": "relu",
            "init_div_is_residual": True,
            "emb_init_std": None,
            "emb_init_uniform_lim": None,
            "init_std": None,
            "init_gain": 0.0,
        }
        
        self.gene_encoder_config = gene_encoder_config or {
            "use_norm": False,
        }
        
        self.expression_encoder_config = expression_encoder_config or {
            "input_emb_style": "continuous",
            "dropout": 0.1,
            "max_value": 512,
            "activation": "relu",
            "use_norm": False,
        }
        
        self.expression_decoder_config = expression_decoder_config or {
            "n_outputs": 1,
            "n_layers": 2,
            "activation": "leaky_relu",
        }
        
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config
    
    @classmethod
    def from_yaml_configs(cls, model_config_dict: Dict, collator_config_dict: Dict) -> "TXConfig":
        """
        Create TXConfig from model_config.yml and collator_config.yml dictionaries
        
        Args:
            model_config_dict: Dictionary from model_config.yml
            collator_config_dict: Dictionary from collator_config.yml
            
        Returns:
            TXConfig instance
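
        Example:
            >>> # Illustrative only: the file names are assumptions; the
            >>> # dictionaries can come from any YAML loader (here PyYAML).
            >>> import yaml
            >>> with open("model_config.yml") as f:
            ...     model_cfg = yaml.safe_load(f)
            >>> with open("collator_config.yml") as f:
            ...     collator_cfg = yaml.safe_load(f)
            >>> config = TXConfig.from_yaml_configs(model_cfg, collator_cfg)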
        """
        # Fall back to the constructor defaults when a key is missing, rather
        # than passing None through to the config.
        return cls(
            vocab_size=model_config_dict.get("vocab_size", 30000),
            d_model=model_config_dict.get("d_model", 512),
            n_layers=model_config_dict.get("n_layers", 12),
            n_heads=model_config_dict.get("n_heads", 8),
            expansion_ratio=model_config_dict.get("expansion_ratio", 4),
            norm_scheme=model_config_dict.get("norm_scheme", "pre"),
            transformer_activation=model_config_dict.get("transformer_activation", "gelu"),
            cell_emb_style=model_config_dict.get("cell_emb_style", "cls"),
            pad_token_id=collator_config_dict.get("pad_token_id", 0),
            pad_value=collator_config_dict.get("pad_value", 0.0),
            num_bins=collator_config_dict.get("num_bins", 51),
            use_chem_token=collator_config_dict.get("use_chem_token", False),
            attn_config=model_config_dict.get("attn_config"),
            norm_config=model_config_dict.get("norm_config"),
            init_config=model_config_dict.get("init_config"),
            gene_encoder_config=model_config_dict.get("gene_encoder"),
            expression_encoder_config=model_config_dict.get("expression_encoder"),
            expression_decoder_config=model_config_dict.get("expression_decoder"),
            mvc_config=model_config_dict.get("mvc"),
            chemical_encoder_config=model_config_dict.get("chemical_encoder"),
            use_glu=model_config_dict.get("use_glu", False),
            return_gene_embeddings=model_config_dict.get("return_gene_embeddings", False),
            standard_scale_outputs=model_config_dict.get("standard_scale_outputs", False),
            keep_first_n_tokens=collator_config_dict.get("keep_first_n_tokens", 1),
        )
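

# Illustrative round-trip sketch: TXConfig inherits HuggingFace's standard
# save_pretrained / from_pretrained JSON serialization from PretrainedConfig.
# The directory name below is arbitrary.
if __name__ == "__main__":
    config = TXConfig(d_model=256, n_layers=6)
    config.save_pretrained("./tx_config_demo")  # writes config.json
    reloaded = TXConfig.from_pretrained("./tx_config_demo")
    assert reloaded.d_model == 256 and reloaded.n_layers == 6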