# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
TXModel - Complete Standalone Implementation for HuggingFace

All code in one file - requires ONLY: transformers, torch, safetensors
"""
import math
from typing import Optional, Dict, Any, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput


# =============================================================================
# CONFIGURATION
# =============================================================================
class TXConfig(PretrainedConfig):
    """Configuration for TXModel"""

    model_type = "tx_model"

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu
        self.attn_config = attn_config or {}
        self.norm_config = norm_config or {}
        self.gene_encoder_config = gene_encoder_config or {}
        self.expression_encoder_config = expression_encoder_config or {}
        self.expression_decoder_config = expression_decoder_config or {}
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config
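

# -----------------------------------------------------------------------------
# Example (illustrative only, not part of the model): constructing a small
# TXConfig. The helper name and every hyperparameter value below are arbitrary
# assumptions chosen for demonstration; they are not the published
# checkpoint's settings.
# -----------------------------------------------------------------------------
def _example_config() -> TXConfig:
    # Nested sub-configs are plain dicts read with .get() by the modules below.
    return TXConfig(
        vocab_size=100,
        d_model=64,
        n_layers=2,
        n_heads=4,
        attn_config={"kv_n_heads": 2, "attn_pdrop": 0.0},
        expression_encoder_config={"dropout": 0.1, "max_value": 512},
        mvc_config={"arch_style": "inner product"},
    )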


# =============================================================================
# MODEL BLOCKS
# =============================================================================
class MultiheadAttention(nn.Module):
    """Multi-head attention with grouped query support"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads
        self.head_dim = d_model // n_heads
        self.dropout = dropout
        self.n_rep = n_heads // self.kv_n_heads
        self.q_proj = nn.Linear(d_model, d_model, device=device)
        self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
        self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
        self.out_proj = nn.Linear(d_model, d_model, device=device)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tuple[Tensor, None, None]:
        batch_size, seq_len, _ = x.shape
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if key_padding_mask is not None:
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        return output, None, None
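

# -----------------------------------------------------------------------------
# Example (illustrative only): grouped-query attention, where kv_n_heads is
# smaller than n_heads and key/value heads are repeat-interleaved to match the
# query heads. The helper name and all shapes are assumptions for demonstration.
# -----------------------------------------------------------------------------
def _example_grouped_query_attention() -> Tensor:
    attn = MultiheadAttention(d_model=64, n_heads=8, kv_n_heads=2)
    x = torch.randn(2, 16, 64)                   # (batch, seq, d_model)
    keep = torch.ones(2, 16, dtype=torch.bool)   # True = attend, False = padding
    out, _, _ = attn(x, key_padding_mask=keep)
    return out                                   # (2, 16, 64)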


class TXBlock(nn.Module):
    """Transformer encoder block"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: float = 0.0,
        activation: str = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs
    ):
        super().__init__()
        attn_config = attn_config or {}
        norm_config = norm_config or {}
        self.d_model = d_model
        self.n_heads = n_heads
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu
        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            dropout=attn_config.get("attn_pdrop", 0.0),
            device=device,
        )
        dim_feedforward = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)
        if use_glu:
            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)
        eps = norm_config.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)
        self.activation = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "silu": nn.SiLU(),
            "leaky_relu": nn.LeakyReLU(),
        }.get(activation, nn.GELU())

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tensor:
        if self.norm_scheme == "pre":
            x = x + self._sa_block(self.norm1(x), key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x, _, _ = self.self_attn(x, key_padding_mask=key_padding_mask)
        return self.post_sa_dropout(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        if self.use_glu:
            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
        else:
            x = self.down_proj(self.activation(self.up_proj(x)))
        return self.post_ffn_dropout(x)


class TXEncoder(nn.Module):
    """Stack of transformer encoder layers"""

    def __init__(
        self,
        encoder_layer: TXBlock,
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__()
        norm_config = norm_config or {}
        self.layers = nn.ModuleList([
            TXBlock(
                d_model=encoder_layer.d_model,
                n_heads=encoder_layer.n_heads,
                expansion_ratio=encoder_layer.up_proj.out_features // encoder_layer.d_model,
                norm_scheme=encoder_layer.norm_scheme,
                use_glu=encoder_layer.use_glu,
            )
            for _ in range(num_layers)
        ])
        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        x = total_embs
        hidden_states = [] if output_hidden_states else None
        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
            if output_hidden_states:
                hidden_states.append(x)
        if self.use_norm:
            x = self.norm(x)
        return x, hidden_states


class GeneEncoder(nn.Module):
    """Gene embedding encoder"""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        **kwargs
    ):
        super().__init__()
        self.use_norm = use_norm
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.project = nn.Identity()
        if self.use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)
        x = self.project(x)
        if self.use_norm:
            x = self.enc_norm(x)
        return x


class ContinuousValueEncoder(nn.Module):
    """Encode continuous expression values"""

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)
        self.activation = {
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "leaky_relu": nn.LeakyReLU(),
        }.get(activation, nn.ReLU())
        self.linear2 = nn.Linear(d_model, d_model)
        self.use_norm = use_norm
        if use_norm:
            self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        x = x.unsqueeze(-1)
        x = torch.clamp(x, max=self.max_value)
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        if self.use_norm:
            x = self.norm(x)
        return self.dropout(x)


class ExprDecoder(nn.Module):
    """Expression value decoder"""

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()
        self.activation = {
            "leaky_relu": nn.LeakyReLU(),
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
        }.get(activation, nn.LeakyReLU())
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(n_layers)])
        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        pred_value = self.out_proj(x)
        if pred_value.shape[-1] == 1:
            pred_value = pred_value.squeeze(-1)
        return {"pred": pred_value}


class MVCDecoder(nn.Module):
    """Masked value prediction decoder"""

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ):
        super().__init__()
        self.scaled_dot_product = scaled_dot_product
        self.gene2query = nn.Linear(d_model, d_model)
        self.query_activation = {
            "sigmoid": nn.Sigmoid(),
            "relu": nn.ReLU(),
            "tanh": nn.Tanh(),
        }.get(query_activation, nn.Sigmoid())
        self.W = nn.Linear(d_model, d_model, bias=False)
        self.arch_style = arch_style

    def forward(self, cell_emb: Tensor, gene_embs: Tensor) -> Dict[str, Tensor]:
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        cell_emb = cell_emb.unsqueeze(2)
        pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
        if self.scaled_dot_product:
            pred_value = pred_value / torch.sqrt(
                torch.tensor(query_vecs.shape[-1], dtype=pred_value.dtype)
            )
        return {"pred": pred_value}
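

# -----------------------------------------------------------------------------
# Example (illustrative only): composing the gene and expression encoders with
# a single transformer block. The helper name, sizes, and random inputs are
# assumptions chosen for demonstration.
# -----------------------------------------------------------------------------
def _example_block_shapes() -> Tensor:
    d_model, batch, seq = 64, 2, 16
    gene_enc = GeneEncoder(num_embeddings=100, embedding_dim=d_model)
    value_enc = ContinuousValueEncoder(d_model=d_model)
    block = TXBlock(d_model=d_model, n_heads=4, expansion_ratio=4)
    genes = torch.randint(1, 100, (batch, seq))   # integer gene token IDs
    values = torch.rand(batch, seq) * 10.0        # continuous expression values
    x = gene_enc(genes) + value_enc(values)       # (batch, seq, d_model)
    return block(x)                               # (batch, seq, d_model)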


# =============================================================================
# MAIN MODEL
# =============================================================================
class TXModel(nn.Module):
    """Transformer model for genomic data"""

    def __init__(self, config: TXConfig):
        super().__init__()
        self.config = config
        self.gene_encoder = GeneEncoder(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id,
            use_norm=config.gene_encoder_config.get("use_norm", False),
        )
        self.flag_encoder = nn.Embedding(2, config.d_model)
        self.expression_encoder = ContinuousValueEncoder(
            d_model=config.d_model,
            dropout=config.expression_encoder_config.get("dropout", 0.1),
            max_value=config.expression_encoder_config.get("max_value", 512),
            activation=config.expression_encoder_config.get("activation", "relu"),
            use_norm=config.expression_encoder_config.get("use_norm", False),
        )
        encoder_layer = TXBlock(
            d_model=config.d_model,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            activation=config.transformer_activation,
            norm_scheme=config.norm_scheme,
            use_glu=config.use_glu,
        )
        self.transformer_encoder = TXEncoder(
            encoder_layer,
            config.n_layers,
            use_norm=(config.norm_scheme == "pre"),
            norm_config=config.norm_config,
        )
        self.expression_decoder = ExprDecoder(
            d_model=config.d_model,
            n_outputs=config.expression_decoder_config.get("n_outputs", 1),
            n_layers=config.expression_decoder_config.get("n_layers", 2),
            activation=config.expression_decoder_config.get("activation", "leaky_relu"),
        )
        if config.mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=config.d_model,
                arch_style=config.mvc_config.get("arch_style", "inner product"),
                query_activation=config.mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=config.mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        # Encode
        token_embs = self.gene_encoder(genes)
        token_values = self.expression_encoder(values)
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)
        flag = self.flag_encoder(torch.tensor(1, device=token_embs.device)).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag
        total_embs = token_embs + token_values + flag_embs
        self.cur_gene_token_embs = token_embs

        # Transform
        transformer_output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )

        # Cell embedding
        cell_emb = transformer_output[:, 0, :]

        output = {
            "transformer_output": transformer_output,
            "cell_emb": cell_emb,
        }
        if output_hidden_states:
            output["hidden_states"] = hidden_states
        if skip_decoders:
            return output

        # Decode
        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]
        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(cell_emb, self.cur_gene_token_embs)
            output["mvc_output"] = mvc_output["pred"]
        return output
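

# -----------------------------------------------------------------------------
# Example (illustrative only): a minimal forward pass through the raw TXModel.
# The helper name, config values, and random inputs are assumptions made for
# demonstration; they are not the published model's settings.
# -----------------------------------------------------------------------------
def _example_tx_model_forward() -> dict:
    config = TXConfig(vocab_size=100, d_model=64, n_layers=2, n_heads=4)
    model = TXModel(config)
    batch, seq = 2, 16
    genes = torch.randint(1, config.vocab_size, (batch, seq))
    values = torch.rand(batch, seq)
    gen_masks = torch.zeros(batch, seq, dtype=torch.bool)        # no values masked
    key_padding_mask = torch.ones(batch, seq, dtype=torch.bool)  # True = real token
    out = model(genes, values, gen_masks, key_padding_mask)
    # out["cell_emb"]: (batch, d_model); out["expr_preds"]: (batch, seq)
    return out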


# =============================================================================
# HUGGINGFACE WRAPPER
# =============================================================================
class TXPreTrainedModel(PreTrainedModel):
    """Base class for TXModel"""

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class TXModelForHF(TXPreTrainedModel):
    """
    HuggingFace-compatible TXModel

    Requires ONLY: transformers, torch, safetensors
    """

    def __init__(self, config: TXConfig):
        super().__init__(config)
        self.tx_model = TXModel(config)
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs
    ) -> Union[Tuple, BaseModelOutput]:
        if key_padding_mask is None:
            key_padding_mask = ~genes.eq(self.config.pad_token_id)
        outputs = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=key_padding_mask,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )
        if not return_dict:
            return tuple(v for v in outputs.values())
        return BaseModelOutput(
            last_hidden_state=outputs.get("cell_emb"),
            hidden_states=outputs.get("hidden_states") if output_hidden_states else None,
        )

    def get_input_embeddings(self):
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        self.tx_model.gene_encoder.embedding = value


# Aliases
TXForCausalLM = TXModelForHF
AutoModelForCausalLM = TXModelForHF
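

# -----------------------------------------------------------------------------
# Example (illustrative only): end-to-end use of the HuggingFace wrapper.
# The config values, random tensors, and save directory "./tx_model_demo" are
# assumptions made for demonstration; adapt them to the actual checkpoint and
# data pipeline.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    config = TXConfig(vocab_size=100, d_model=64, n_layers=2, n_heads=4)
    model = TXModelForHF(config)
    batch, seq = 2, 16
    genes = torch.randint(1, config.vocab_size, (batch, seq))
    values = torch.rand(batch, seq)
    gen_masks = torch.zeros(batch, seq, dtype=torch.bool)
    output = model(genes=genes, values=values, gen_masks=gen_masks)
    print(output.last_hidden_state.shape)  # cell embedding: (batch, d_model)

    # Standard PreTrainedModel round trip via save_pretrained / from_pretrained.
    model.save_pretrained("./tx_model_demo")
    reloaded = TXModelForHF.from_pretrained("./tx_model_demo")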