# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Standalone implementation of TXModel with no dependencies beyond:
torch, transformers, safetensors.
"""
from typing import Optional, Tuple

import torch
from torch import Tensor, nn

from blocks_standalone import (
    ChemEncoder,
    ContinuousValueEncoder,
    ExprDecoder,
    GeneEncoder,
    MVCDecoder,
    TXBlock,
    TXEncoder,
)


class TXModel(nn.Module):
    """Standalone Transformer model for genomic data."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        expansion_ratio: int,
        pad_token_id: int,
        pad_value: float,
        num_bins: int,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        use_chem_token: bool = False,
        attn_config: Optional[dict] = None,
        norm_config: Optional[dict] = None,
        gene_encoder_config: Optional[dict] = None,
        expression_encoder_config: Optional[dict] = None,
        expression_decoder_config: Optional[dict] = None,
        mvc_config: Optional[dict] = None,
        chemical_encoder_config: Optional[dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        keep_first_n_tokens: int = 1,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.model_type = "Transformer"
        self.device = device
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.use_chem_token = use_chem_token
        self.cell_emb_style = cell_emb_style
        self.pad_token_id = pad_token_id
        self.pad_value = pad_value
        self.n_input_bins = num_bins
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings

        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}
        if gene_encoder_config is None:
            gene_encoder_config = {"use_norm": False}
        if expression_encoder_config is None:
            expression_encoder_config = {}
        if expression_decoder_config is None:
            expression_decoder_config = {}

        # Gene encoder: embeds gene token ids
        self.gene_encoder = GeneEncoder(
            self.vocab_size,
            self.d_model,
            padding_idx=self.pad_token_id,
            use_norm=gene_encoder_config.get("use_norm", False),
            gene_encoder_cfg=gene_encoder_config,
        )

        # Flag encoder: embeds the binary "to be generated" flag
        self.flag_encoder = nn.Embedding(2, self.d_model)

        # Expression encoder: embeds continuous expression values
        self.expression_encoder = ContinuousValueEncoder(
            d_model=self.d_model,
            dropout=expression_encoder_config.get("dropout", 0.1),
            max_value=expression_encoder_config.get("max_value", 512),
            activation=expression_encoder_config.get("activation", "relu"),
            use_norm=expression_encoder_config.get("use_norm", False),
        )

        # Chemical encoder (only built when a chem token is used)
        if self.use_chem_token:
            if chemical_encoder_config is None:
                chemical_encoder_config = {}
            self.chem_encoder = ChemEncoder(
                d_out=self.d_model,
                padding_idx=chemical_encoder_config.get("padding_idx", 0),
                activation=chemical_encoder_config.get("activation", "leaky_relu"),
                freeze=chemical_encoder_config.get("freeze", False),
                num_drugs=chemical_encoder_config.get("num_drugs", 1000),
                fp_dim=chemical_encoder_config.get("fp_dim", 2048),
            )

        # Transformer encoder
        encoder_layer = TXBlock(
            d_model=self.d_model,
            n_heads=self.n_heads,
            expansion_ratio=self.expansion_ratio,
            attn_config=attn_config,
            norm_config=norm_config,
            activation=self.transformer_activation,
            device=self.device,
            norm_scheme=self.norm_scheme,
            use_glu=use_glu,
        )
        # A final norm layer is applied only in the pre-norm scheme
        self.transformer_encoder = TXEncoder(
            encoder_layer,
            self.n_layers,
            use_norm=self.norm_scheme == "pre",
            norm_config=norm_config,
            attn_config=attn_config,
        )

        # Expression decoder
        self.expression_decoder = ExprDecoder(
            d_model=self.d_model,
            n_outputs=expression_decoder_config.get("n_outputs", 1),
            n_layers=expression_decoder_config.get("n_layers", 2),
            activation=expression_decoder_config.get("activation", "leaky_relu"),
        )

        # MVC decoder (only built when configured)
        if mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=self.d_model,
                arch_style=mvc_config.get("arch_style", "inner product"),
                query_activation=mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None

    def transformer_generate(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        # Encode gene token ids
        token_embs = self.gene_encoder(genes)

        # Encode expression values; zero them out at positions to be generated
        token_values = self.expression_encoder(values)
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)

        # Flag embeddings mark the positions to be generated
        flag = self.flag_encoder(
            torch.tensor(1, device=token_embs.device)
        ).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag

        # Combine embeddings
        total_embs = token_embs + token_values + flag_embs

        # Chemical embedding replaces the token at position 1 when used
        if self.use_chem_token and drug_ids is not None:
            drug_embs = self.chem_encoder(drug_ids)
            total_embs[:, 1, :] = drug_embs

        # Store gene embeddings for the MVC decoder
        self.cur_gene_token_embs = token_embs

        # Pass through transformer
        output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )
        return output, hidden_states

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        # Generate transformer output
        transformer_output, hidden_states = self.transformer_generate(
            genes, values, gen_masks, key_padding_mask, drug_ids, output_hidden_states
        )

        # Prepare output dict
        output = {
            "transformer_output": transformer_output,
        }
        if output_hidden_states:
            output["hidden_states"] = hidden_states

        # Cell embedding (CLS token or pooling)
        if self.cell_emb_style == "cls":
            cell_emb = transformer_output[:, 0, :]
        elif self.cell_emb_style in ("avg-pool", "w-pool"):
            # w-pool (weighted pooling) is not implemented and falls back to
            # average pooling over non-padding tokens. This assumes the
            # PyTorch convention that key_padding_mask is True at padding
            # positions, hence the inversion before pooling.
            mask = (~key_padding_mask).unsqueeze(-1).float()
            cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            cell_emb = transformer_output[:, 0, :]
        output["cell_emb"] = cell_emb

        # Return per-token gene embeddings if requested
        if self.return_gene_embeddings:
            output["gene_embeddings"] = transformer_output

        # Skip decoders if requested
        if skip_decoders:
            return output

        # Expression decoder
        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]

        # MVC decoder (if available)
        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(
                cell_emb,
                self.cur_gene_token_embs,
            )
            output["mvc_output"] = mvc_output["pred"]

        return output
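    # forward() returns a dict with the following keys:
    #   "transformer_output" : (batch, seq_len, d_model) final token states
    #   "cell_emb"           : (batch, d_model) cell-level embedding
    #   "hidden_states"      : per-layer states, only if output_hidden_states
    #   "gene_embeddings"    : alias of transformer_output, only if
    #                          return_gene_embeddings is set
    #   "expr_preds"         : expression-decoder predictions (absent when
    #                          skip_decoders is True)
    #   "mvc_output"         : MVC-decoder predictions, only when an MVC
    #                          decoder is configured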
    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from pretrained weights."""
        import json
        from pathlib import Path

        from safetensors.torch import load_file

        model_path = Path(model_path)

        # Load config
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Create model
        model = cls(
            vocab_size=config["vocab_size"],
            d_model=config["d_model"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            expansion_ratio=config["expansion_ratio"],
            pad_token_id=config["pad_token_id"],
            pad_value=config["pad_value"],
            num_bins=config["num_bins"],
            norm_scheme=config.get("norm_scheme", "pre"),
            transformer_activation=config.get("transformer_activation", "gelu"),
            cell_emb_style=config.get("cell_emb_style", "cls"),
            use_chem_token=config.get("use_chem_token", False),
            attn_config=config.get("attn_config"),
            norm_config=config.get("norm_config"),
            gene_encoder_config=config.get("gene_encoder_config"),
            expression_encoder_config=config.get("expression_encoder_config"),
            expression_decoder_config=config.get("expression_decoder_config"),
            mvc_config=config.get("mvc_config"),
            chemical_encoder_config=config.get("chemical_encoder_config"),
            use_glu=config.get("use_glu", False),
            return_gene_embeddings=config.get("return_gene_embeddings", False),
            keep_first_n_tokens=config.get("keep_first_n_tokens", 1),
        )

        # Load weights
        state_dict = load_file(model_path / "model.safetensors")

        # Strip 'model.tx_model.', 'tx_model.', or 'model.' prefixes if
        # present; checked longest-first so the most specific prefix wins
        prefixes = ("model.tx_model.", "tx_model.", "model.")
        new_state_dict = {}
        for k, v in state_dict.items():
            for prefix in prefixes:
                if k.startswith(prefix):
                    k = k[len(prefix):]
                    break
            new_state_dict[k] = v

        # strict=False tolerates missing or unexpected keys, e.g. when
        # optional decoders are not configured
        model.load_state_dict(new_state_dict, strict=False)
        return model
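

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the checkpoint path, batch shapes, and
# mask values below are assumptions for demonstration, not part of the
# released API. A checkpoint directory is expected to contain config.json
# and model.safetensors, as read by from_pretrained above.
if __name__ == "__main__":
    model = TXModel.from_pretrained("path/to/checkpoint")  # hypothetical path
    model.eval()

    batch_size, seq_len = 2, 16
    genes = torch.randint(0, model.vocab_size, (batch_size, seq_len))
    values = torch.rand(batch_size, seq_len)
    # All-False masks: no positions marked for generation, no padding
    gen_masks = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)

    with torch.no_grad():
        out = model(genes, values, gen_masks, key_padding_mask)
    print(out["cell_emb"].shape)    # (batch_size, d_model)
    print(out["expr_preds"].shape)  # per-token expression predictions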