# tx-model-standalone/configuration_tx.py
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Configuration class for TXModel compatible with HuggingFace Transformers
"""
from typing import Any, Dict, Optional

from transformers import PretrainedConfig


class TXConfig(PretrainedConfig):
"""
Configuration class for TXModel.
This class stores the configuration of a TXModel, which is a Transformer-based model
for genomic/biological sequence analysis.
Args:
vocab_size (int): Size of the vocabulary
d_model (int): Dimensionality of the model embeddings
n_layers (int): Number of transformer layers
n_heads (int): Number of attention heads
expansion_ratio (int): Expansion ratio for FFN
norm_scheme (str): Normalization scheme ('pre' or 'post')
transformer_activation (str): Activation function for transformer
cell_emb_style (str): Cell embedding style ('cls', 'avg-pool', 'w-pool')
pad_token_id (int): ID of the padding token
pad_value (float): Value for padding
num_bins (int): Number of bins for expression values
use_chem_token (bool): Whether to use chemical token encoder
attn_config (Dict): Attention configuration
norm_config (Dict): Normalization configuration
init_config (Dict): Initialization configuration
gene_encoder_config (Dict): Gene encoder configuration
expression_encoder_config (Dict): Expression encoder configuration
expression_decoder_config (Dict): Expression decoder configuration
mvc_config (Optional[Dict]): MVC decoder configuration
chemical_encoder_config (Optional[Dict]): Chemical encoder configuration
use_glu (bool): Whether to use GLU in FFN
return_gene_embeddings (bool): Whether to return gene embeddings
standard_scale_outputs (bool): Whether to scale outputs
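
    Example:
        A minimal, illustrative instantiation; the keyword values below are
        arbitrary overrides of the defaults defined in ``__init__``::

            config = TXConfig(d_model=256, n_layers=6, n_heads=8)
            assert config.attn_config["attn_type"] == "grouped_query_attention"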
"""
model_type = "tx_model"
    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        init_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu

        # Sub-configurations
        self.attn_config = attn_config or {
            "attn_type": "grouped_query_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "flash",
            "use_attn_mask": False,
            "qk_ln": False,
            "qk_gn": False,
            "clip_qkv": None,
            "softmax_scale": None,
        }
        self.norm_config = norm_config or {
            "norm_type": "low_precision_layernorm",
            "eps": 1e-5,
        }
        self.init_config = init_config or {
            "name": "kaiming_normal_",
            "fan_mode": "fan_in",
            "init_nonlinearity": "relu",
            "init_div_is_residual": True,
            "emb_init_std": None,
            "emb_init_uniform_lim": None,
            "init_std": None,
            "init_gain": 0.0,
        }
        self.gene_encoder_config = gene_encoder_config or {
            "use_norm": False,
        }
        self.expression_encoder_config = expression_encoder_config or {
            "input_emb_style": "continuous",
            "dropout": 0.1,
            "max_value": 512,
            "activation": "relu",
            "use_norm": False,
        }
        self.expression_decoder_config = expression_decoder_config or {
            "n_outputs": 1,
            "n_layers": 2,
            "activation": "leaky_relu",
        }
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config

    @classmethod
    def from_yaml_configs(cls, model_config_dict: Dict, collator_config_dict: Dict) -> "TXConfig":
        """
        Create a TXConfig from ``model_config.yml`` and ``collator_config.yml`` dictionaries.

        Args:
            model_config_dict: Dictionary parsed from ``model_config.yml``
            collator_config_dict: Dictionary parsed from ``collator_config.yml``

        Returns:
            TXConfig instance
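
        Example:
            An illustrative call with minimal hand-written dictionaries standing in
            for the parsed YAML files; the keys mirror the ones read below and the
            values are placeholders::

                model_cfg = {
                    "vocab_size": 30000,
                    "d_model": 512,
                    "n_layers": 12,
                    "n_heads": 8,
                    "expansion_ratio": 4,
                }
                collator_cfg = {"pad_token_id": 0, "num_bins": 51}
                config = TXConfig.from_yaml_configs(model_cfg, collator_cfg)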
"""
        return cls(
            vocab_size=model_config_dict.get("vocab_size"),
            d_model=model_config_dict.get("d_model"),
            n_layers=model_config_dict.get("n_layers"),
            n_heads=model_config_dict.get("n_heads"),
            expansion_ratio=model_config_dict.get("expansion_ratio"),
            norm_scheme=model_config_dict.get("norm_scheme", "pre"),
            transformer_activation=model_config_dict.get("transformer_activation", "gelu"),
            cell_emb_style=model_config_dict.get("cell_emb_style", "cls"),
            pad_token_id=collator_config_dict.get("pad_token_id", 0),
            pad_value=collator_config_dict.get("pad_value", 0.0),
            num_bins=collator_config_dict.get("num_bins", 51),
            use_chem_token=collator_config_dict.get("use_chem_token", False),
            attn_config=model_config_dict.get("attn_config"),
            norm_config=model_config_dict.get("norm_config"),
            init_config=model_config_dict.get("init_config"),
            gene_encoder_config=model_config_dict.get("gene_encoder"),
            expression_encoder_config=model_config_dict.get("expression_encoder"),
            expression_decoder_config=model_config_dict.get("expression_decoder"),
            mvc_config=model_config_dict.get("mvc"),
            chemical_encoder_config=model_config_dict.get("chemical_encoder"),
            use_glu=model_config_dict.get("use_glu", False),
            return_gene_embeddings=model_config_dict.get("return_gene_embeddings", False),
            standard_scale_outputs=model_config_dict.get("standard_scale_outputs", False),
            keep_first_n_tokens=collator_config_dict.get("keep_first_n_tokens", 1),
        )
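

if __name__ == "__main__":
    # Example usage (illustrative sketch): round-trip the config through the standard
    # PretrainedConfig serialization API. The directory name is an arbitrary placeholder.
    config = TXConfig()
    config.save_pretrained("./tx_config_example")  # writes config.json
    reloaded = TXConfig.from_pretrained("./tx_config_example")
    print(reloaded.model_type, reloaded.d_model, reloaded.n_layers)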