|
|
|
|
|
""" |
|
|
Configuration class for TXModel compatible with HuggingFace Transformers |
|
|
""" |
|
|
|
|
|
from transformers import PretrainedConfig |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
|
|
|
class TXConfig(PretrainedConfig):
    """
    Configuration class for TXModel.

    This class stores the configuration of a TXModel, which is a Transformer-based model
    for genomic/biological sequence analysis.

    Args:
        vocab_size (int): Size of the vocabulary
        d_model (int): Dimensionality of the model embeddings
        n_layers (int): Number of transformer layers
        n_heads (int): Number of attention heads
        expansion_ratio (int): Expansion ratio for FFN
        norm_scheme (str): Normalization scheme ('pre' or 'post')
        transformer_activation (str): Activation function for transformer
        cell_emb_style (str): Cell embedding style ('cls', 'avg-pool', 'w-pool')
        pad_token_id (int): ID of the padding token
        pad_value (float): Value for padding
        num_bins (int): Number of bins for expression values
        use_chem_token (bool): Whether to use chemical token encoder
        attn_config (Dict): Attention configuration
        norm_config (Dict): Normalization configuration
        init_config (Dict): Initialization configuration
        gene_encoder_config (Dict): Gene encoder configuration
        expression_encoder_config (Dict): Expression encoder configuration
        expression_decoder_config (Dict): Expression decoder configuration
        mvc_config (Optional[Dict]): MVC decoder configuration
        chemical_encoder_config (Optional[Dict]): Chemical encoder configuration
        use_glu (bool): Whether to use GLU in FFN
        return_gene_embeddings (bool): Whether to return gene embeddings
        standard_scale_outputs (bool): Whether to scale outputs
        keep_first_n_tokens (int): Number of leading tokens to keep
            (sourced from the collator config in ``from_yaml_configs``)
    """

    model_type = "tx_model"

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        init_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs,
    ):
        # pad_token_id is managed by the PretrainedConfig base class.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu

        # NOTE: use explicit `is None` checks (not `config or {...}`) so that an
        # explicitly-passed empty dict is respected rather than being silently
        # replaced by the built-in defaults.
        if attn_config is None:
            attn_config = {
                "attn_type": "grouped_query_attention",
                "attn_pdrop": 0.0,
                "attn_impl": "flash",
                "use_attn_mask": False,
                "qk_ln": False,
                "qk_gn": False,
                "clip_qkv": None,
                "softmax_scale": None,
            }
        self.attn_config = attn_config

        if norm_config is None:
            norm_config = {
                "norm_type": "low_precision_layernorm",
                "eps": 1e-5,
            }
        self.norm_config = norm_config

        if init_config is None:
            init_config = {
                "name": "kaiming_normal_",
                "fan_mode": "fan_in",
                "init_nonlinearity": "relu",
                "init_div_is_residual": True,
                "emb_init_std": None,
                "emb_init_uniform_lim": None,
                "init_std": None,
                "init_gain": 0.0,
            }
        self.init_config = init_config

        if gene_encoder_config is None:
            gene_encoder_config = {
                "use_norm": False,
            }
        self.gene_encoder_config = gene_encoder_config

        if expression_encoder_config is None:
            expression_encoder_config = {
                "input_emb_style": "continuous",
                "dropout": 0.1,
                "max_value": 512,
                "activation": "relu",
                "use_norm": False,
            }
        self.expression_encoder_config = expression_encoder_config

        if expression_decoder_config is None:
            expression_decoder_config = {
                "n_outputs": 1,
                "n_layers": 2,
                "activation": "leaky_relu",
            }
        self.expression_decoder_config = expression_decoder_config

        # These two are genuinely optional sub-modules: None means "disabled",
        # so no default dict is substituted.
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config

    @classmethod
    def from_yaml_configs(cls, model_config_dict: Dict, collator_config_dict: Dict) -> "TXConfig":
        """
        Create a TXConfig from model_config.yml and collator_config.yml dictionaries.

        Args:
            model_config_dict: Dictionary loaded from model_config.yml
            collator_config_dict: Dictionary loaded from collator_config.yml

        Returns:
            TXConfig instance

        Note:
            Every scalar lookup supplies a fallback matching the ``__init__``
            default. Previously ``vocab_size``/``d_model``/``n_layers``/
            ``n_heads``/``expansion_ratio`` used ``.get(key)`` with no default,
            so a missing key passed ``None`` through and clobbered the
            constructor defaults.
        """
        return cls(
            vocab_size=model_config_dict.get("vocab_size", 30000),
            d_model=model_config_dict.get("d_model", 512),
            n_layers=model_config_dict.get("n_layers", 12),
            n_heads=model_config_dict.get("n_heads", 8),
            expansion_ratio=model_config_dict.get("expansion_ratio", 4),
            norm_scheme=model_config_dict.get("norm_scheme", "pre"),
            transformer_activation=model_config_dict.get("transformer_activation", "gelu"),
            cell_emb_style=model_config_dict.get("cell_emb_style", "cls"),
            pad_token_id=collator_config_dict.get("pad_token_id", 0),
            pad_value=collator_config_dict.get("pad_value", 0.0),
            num_bins=collator_config_dict.get("num_bins", 51),
            use_chem_token=collator_config_dict.get("use_chem_token", False),
            attn_config=model_config_dict.get("attn_config"),
            norm_config=model_config_dict.get("norm_config"),
            init_config=model_config_dict.get("init_config"),
            gene_encoder_config=model_config_dict.get("gene_encoder"),
            expression_encoder_config=model_config_dict.get("expression_encoder"),
            expression_decoder_config=model_config_dict.get("expression_decoder"),
            mvc_config=model_config_dict.get("mvc"),
            chemical_encoder_config=model_config_dict.get("chemical_encoder"),
            use_glu=model_config_dict.get("use_glu", False),
            return_gene_embeddings=model_config_dict.get("return_gene_embeddings", False),
            standard_scale_outputs=model_config_dict.get("standard_scale_outputs", False),
            keep_first_n_tokens=collator_config_dict.get("keep_first_n_tokens", 1),
        )
|
|
|