# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Configuration class for TXModel compatible with HuggingFace Transformers.
"""

from typing import Dict, Optional

from transformers import PretrainedConfig


class TXConfig(PretrainedConfig):
    """
    Configuration class for TXModel.

    This class stores the configuration of a TXModel, a Transformer-based
    model for genomic/biological sequence analysis.

    Args:
        vocab_size (int): Size of the vocabulary
        d_model (int): Dimensionality of the model embeddings
        n_layers (int): Number of transformer layers
        n_heads (int): Number of attention heads
        expansion_ratio (int): Expansion ratio for the FFN hidden dimension
        norm_scheme (str): Normalization scheme ('pre' or 'post')
        transformer_activation (str): Activation function for the transformer
        cell_emb_style (str): Cell embedding style ('cls', 'avg-pool', 'w-pool')
        pad_token_id (int): ID of the padding token
        pad_value (float): Value used to pad expression values
        num_bins (int): Number of bins for expression values
        use_chem_token (bool): Whether to use the chemical token encoder
        attn_config (Optional[Dict]): Attention configuration
        norm_config (Optional[Dict]): Normalization configuration
        init_config (Optional[Dict]): Weight initialization configuration
        gene_encoder_config (Optional[Dict]): Gene encoder configuration
        expression_encoder_config (Optional[Dict]): Expression encoder configuration
        expression_decoder_config (Optional[Dict]): Expression decoder configuration
        mvc_config (Optional[Dict]): MVC decoder configuration
        chemical_encoder_config (Optional[Dict]): Chemical encoder configuration
        use_glu (bool): Whether to use GLU in the FFN
        return_gene_embeddings (bool): Whether to return gene embeddings
        standard_scale_outputs (bool): Whether to standard-scale model outputs
        keep_first_n_tokens (int): Number of leading tokens to keep
            (taken from the collator configuration)
    """

    model_type = "tx_model"

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        init_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs,
    ):
        # PretrainedConfig stores pad_token_id and any extra HF kwargs.
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu

        # Sub-configurations; each falls back to a default dict when None.
        self.attn_config = attn_config or {
            "attn_type": "grouped_query_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "flash",
            "use_attn_mask": False,
            "qk_ln": False,
            "qk_gn": False,
            "clip_qkv": None,
            "softmax_scale": None,
        }
        self.norm_config = norm_config or {
            "norm_type": "low_precision_layernorm",
            "eps": 1e-5,
        }
        self.init_config = init_config or {
            "name": "kaiming_normal_",
            "fan_mode": "fan_in",
            "init_nonlinearity": "relu",
            "init_div_is_residual": True,
"emb_init_std": None, "emb_init_uniform_lim": None, "init_std": None, "init_gain": 0.0, } self.gene_encoder_config = gene_encoder_config or { "use_norm": False, } self.expression_encoder_config = expression_encoder_config or { "input_emb_style": "continuous", "dropout": 0.1, "max_value": 512, "activation": "relu", "use_norm": False, } self.expression_decoder_config = expression_decoder_config or { "n_outputs": 1, "n_layers": 2, "activation": "leaky_relu", } self.mvc_config = mvc_config self.chemical_encoder_config = chemical_encoder_config @classmethod def from_yaml_configs(cls, model_config_dict: Dict, collator_config_dict: Dict) -> "TXConfig": """ Create TXConfig from model_config.yml and collator_config.yml dictionaries Args: model_config_dict: Dictionary from model_config.yml collator_config_dict: Dictionary from collator_config.yml Returns: TXConfig instance """ return cls( vocab_size=model_config_dict.get("vocab_size"), d_model=model_config_dict.get("d_model"), n_layers=model_config_dict.get("n_layers"), n_heads=model_config_dict.get("n_heads"), expansion_ratio=model_config_dict.get("expansion_ratio"), norm_scheme=model_config_dict.get("norm_scheme", "pre"), transformer_activation=model_config_dict.get("transformer_activation", "gelu"), cell_emb_style=model_config_dict.get("cell_emb_style", "cls"), pad_token_id=collator_config_dict.get("pad_token_id", 0), pad_value=collator_config_dict.get("pad_value", 0.0), num_bins=collator_config_dict.get("num_bins", 51), use_chem_token=collator_config_dict.get("use_chem_token", False), attn_config=model_config_dict.get("attn_config"), norm_config=model_config_dict.get("norm_config"), init_config=model_config_dict.get("init_config"), gene_encoder_config=model_config_dict.get("gene_encoder"), expression_encoder_config=model_config_dict.get("expression_encoder"), expression_decoder_config=model_config_dict.get("expression_decoder"), mvc_config=model_config_dict.get("mvc"), chemical_encoder_config=model_config_dict.get("chemical_encoder"), use_glu=model_config_dict.get("use_glu", False), return_gene_embeddings=model_config_dict.get("return_gene_embeddings", False), standard_scale_outputs=model_config_dict.get("standard_scale_outputs", False), keep_first_n_tokens=collator_config_dict.get("keep_first_n_tokens", 1), )