# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""Standalone implementation of TXModel blocks without external dependencies.

Only requires: torch, transformers
"""

import math
from typing import Optional, Dict, Any, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class MultiheadAttention(nn.Module):
    """Standard multi-head attention implementation.

    Supports grouped-query attention (GQA): when ``kv_n_heads`` < ``n_heads``,
    each key/value head is shared by ``n_heads // kv_n_heads`` query heads and
    is repeated to full width before the attention product.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads
        self.head_dim = d_model // n_heads
        self.dropout = dropout

        # Grouped Query Attention support: each k/v head serves n_rep query heads.
        self.n_rep = n_heads // self.kv_n_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
        self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
        self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False,
        **kwargs,
    ) -> Tuple[Tensor, None, None]:
        """Self-attention over ``x``.

        Args:
            x: Input of shape (batch, seq_len, d_model).
            attn_bias: Optional additive bias broadcastable to the
                (batch, n_heads, seq_len, seq_len) score tensor.
            key_padding_mask: Optional (batch, seq_len) boolean mask with
                ``True`` for VALID positions; ``False`` keys are excluded.
            is_causal: If ``True``, apply an upper-triangular causal mask.

        Returns:
            ``(output, None, None)`` — the two ``None`` slots keep the return
            arity compatible with callers that also expect attention
            weights/state from other attention implementations.
        """
        batch_size, seq_len, _ = x.shape

        # Project queries/keys/values and split into heads:
        # (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Repeat k/v heads so GQA matches the query head count.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product attention scores.
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Apply attention bias if provided.
        if attn_bias is not None:
            attn_scores = attn_scores + attn_bias

        # Apply key padding mask: (batch, seq) -> (batch, 1, 1, seq).
        # NOTE(review): a row whose keys are ALL padded softmaxes over all
        # -inf and yields NaNs — callers are assumed never to pass such rows.
        if key_padding_mask is not None:
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

        # Apply causal mask if needed (strict upper triangle is blocked).
        if is_causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        # Softmax and dropout, then weighted sum of values.
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        output = torch.matmul(attn_weights, v)

        # Merge heads back: (batch, seq, d_model) and project out.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)
        return output, None, None


class TXBlock(nn.Module):
    """Transformer encoder block with pre/post normalization support.

    The feed-forward network is either a standard MLP
    (``down(act(up(x)))``) or, when ``use_glu`` is set, a gated-linear-unit
    variant (``down(act(gate(x)) * up(x))``).
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: Optional[float] = 0.0,
        activation: Optional[str] = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Fail fast on invalid config before building any submodules.
        if norm_scheme not in ["pre", "post"]:
            raise ValueError("norm_scheme must be either pre or post")
        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}

        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device  # retained so TXEncoder can clone this block
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu

        # Attention
        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            dropout=attn_config.get("attn_pdrop", 0.0),
            device=device,
        )

        # FFN
        dim_feedforward = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)
        if use_glu:
            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)

        # Normalization
        eps = norm_config.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)

        # Dropout after attention and after the FFN, respectively.
        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)

        # Activation
        self.activation = self._get_activation_fn(activation)

    @staticmethod
    def _get_activation_fn(activation: str):
        """Map an activation name to its ``nn.Module``; raise on unknown names."""
        if activation == "gelu":
            return nn.GELU()
        elif activation == "relu":
            return nn.ReLU()
        elif activation == "silu" or activation == "swish":
            return nn.SiLU()
        elif activation == "leaky_relu":
            return nn.LeakyReLU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Run one encoder block; ``x`` is (batch, seq_len, d_model)."""
        if self.norm_scheme == "pre":
            # Pre-norm: norm -> sublayer -> residual add
            x = x + self._sa_block(
                self.norm1(x),
                attn_bias=attn_bias,
                key_padding_mask=key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm: sublayer -> residual add -> norm
            x = self.norm1(
                x
                + self._sa_block(
                    x,
                    attn_bias=attn_bias,
                    key_padding_mask=key_padding_mask,
                )
            )
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Self-attention sublayer (non-causal) followed by dropout."""
        x, _, _ = self.self_attn(
            x,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=False,
        )
        return self.post_sa_dropout(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sublayer followed by dropout."""
        if self.use_glu:
            # GLU variant: down(activation(gate(x)) * up(x))
            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
        else:
            # Standard FFN: down(activation(up(x)))
            x = self.down_proj(self.activation(self.up_proj(x)))
        return self.post_ffn_dropout(x)


class TXEncoder(nn.Module):
    """Stack of transformer encoder layers.

    ``encoder_layer`` acts as a template: its dimensions and flags are read
    to build ``num_layers`` fresh (independently initialized) TXBlocks.
    """

    def __init__(
        self,
        encoder_layer: TXBlock,
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        attn_config: Optional[Dict] = None,
    ):
        super().__init__()
        if norm_config is None:
            norm_config = {}

        # Clone the template layer. The expansion ratio is recovered from the
        # template's up-projection width.
        # NOTE(review): clones always use activation="gelu" and the TXBlock
        # default dropout, regardless of how the template was configured —
        # confirm this is intended before templating with other settings.
        self.layers = nn.ModuleList([
            TXBlock(
                d_model=encoder_layer.d_model,
                n_heads=encoder_layer.n_heads,
                expansion_ratio=encoder_layer.up_proj.out_features // encoder_layer.d_model,
                attn_config=attn_config,
                norm_config=norm_config,
                activation="gelu",
                device=encoder_layer.device,
                norm_scheme=encoder_layer.norm_scheme,
                use_glu=encoder_layer.use_glu,
            )
            for _ in range(num_layers)
        ])

        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        """Run the layer stack.

        Args:
            total_embs: Input embeddings, (batch, seq_len, d_model).
            key_padding_mask: Optional (batch, seq_len) boolean mask with
                ``True`` for valid positions, forwarded to every layer.
            output_hidden_states: If ``True``, collect each layer's output
                (before the optional final norm).

        Returns:
            ``(x, hidden_states)`` where ``hidden_states`` is a list of
            per-layer outputs or ``None``.
        """
        x = total_embs
        hidden_states = [] if output_hidden_states else None

        for layer in self.layers:
            x = layer(
                x,
                attn_bias=None,
                key_padding_mask=key_padding_mask,
            )
            if output_hidden_states:
                hidden_states.append(x)

        if self.use_norm:
            x = self.norm(x)

        return x, hidden_states


class GeneEncoder(nn.Module):
    """Gene embedding with optional extra embeddings.

    In this standalone version ``project`` is the identity; the hook is kept
    so downstream code can swap in extra-embedding projections.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        gene_encoder_cfg: Optional[Dict] = None,
    ):
        super().__init__()
        if gene_encoder_cfg is None:
            gene_encoder_cfg = {}
        self.use_norm = use_norm

        self.embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
        )

        # For now, no extra embeddings in standalone version
        self.project = nn.Identity()

        if self.use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Map integer gene ids to embeddings; output is (*x.shape, dim)."""
        x = self.embedding(x)
        x = self.project(x)
        if self.use_norm:
            x = self.enc_norm(x)
        return x


class ChemEncoder(nn.Module):
    """Chemical compound encoder.

    The embedding table is a fingerprint matrix indexed by drug id; it is
    initialized to zeros and expected to be overwritten with pretrained
    fingerprints by the user.
    """

    def __init__(
        self,
        d_out: int,
        padding_idx: int = 0,
        activation: str = "leaky_relu",
        use_norm: bool = True,
        freeze: bool = False,
        num_drugs: int = 1000,
        fp_dim: int = 2048,
    ):
        super().__init__()
        # Initialize with zeros (user should load pretrained weights)
        drug_fps = torch.zeros((num_drugs, fp_dim), dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(
            drug_fps,
            padding_idx=padding_idx,
            freeze=freeze,
        )
        self.fc = nn.Linear(fp_dim, d_out)

        # Unknown activation names silently fall back to ReLU here (no raise),
        # unlike TXBlock._get_activation_fn — kept for backward compatibility.
        if activation == "leaky_relu":
            self.activation = nn.LeakyReLU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "gelu":
            self.activation = nn.GELU()
        else:
            self.activation = nn.ReLU()

        self.proj = nn.Linear(d_out, d_out)
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = nn.LayerNorm(d_out)

    def forward(self, x: Tensor) -> Tensor:
        """Map integer drug ids to ``d_out``-dim encodings."""
        x = self.embedding(x)
        x = self.activation(self.fc(x))
        x = self.proj(x)
        if self.use_norm:
            x = self.norm(x)
        return x


class ContinuousValueEncoder(nn.Module):
    """Encode continuous values to embeddings.

    Values are clipped from above at ``max_value`` and passed through a
    two-layer MLP; dropout is applied to the final embedding.
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)

        # Unknown activation names silently fall back to ReLU (no raise).
        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "leaky_relu":
            self.activation = nn.LeakyReLU()
        else:
            self.activation = nn.ReLU()

        self.linear2 = nn.Linear(d_model, d_model)
        self.use_norm = use_norm
        if self.use_norm:
            self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` of shape (...,) into (..., d_model)."""
        # Expand last dimension so each scalar becomes a 1-vector.
        x = x.unsqueeze(-1)
        # Clip to max value (no lower clip).
        x = torch.clamp(x, max=self.max_value)
        # Project through the two-layer MLP.
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        if self.use_norm:
            x = self.norm(x)
        return self.dropout(x)


class ExprDecoder(nn.Module):
    """Expression value decoder.

    A stack of ``n_layers`` activated linear layers followed by an output
    projection; a trailing singleton output dimension is squeezed away.
    """

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()
        # Unknown activation names silently fall back to LeakyReLU (no raise).
        if activation == "leaky_relu":
            self.activation = nn.LeakyReLU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "gelu":
            self.activation = nn.GELU()
        else:
            self.activation = nn.LeakyReLU()

        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(n_layers)]
        )
        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Decode (..., d_model) into ``{"pred": (...,)}`` (or (..., n_outputs))."""
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        pred_value = self.out_proj(x)
        # Drop a singleton output dimension so scalar predictions are flat.
        if pred_value.shape[-1] == 1:
            pred_value = pred_value.squeeze(-1)
        return {"pred": pred_value}


class MVCDecoder(nn.Module):
    """Masked value prediction decoder.

    Predicts per-gene values as the inner product between activated query
    vectors derived from gene embeddings and the cell embedding.
    """

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ) -> None:
        super().__init__()
        self.scaled_dot_product = scaled_dot_product
        if arch_style == "inner product":
            self.gene2query = nn.Linear(d_model, d_model)
            # Unknown activation names silently fall back to Sigmoid (no raise).
            if query_activation == "sigmoid":
                self.query_activation = nn.Sigmoid()
            elif query_activation == "relu":
                self.query_activation = nn.ReLU()
            elif query_activation == "tanh":
                self.query_activation = nn.Tanh()
            else:
                self.query_activation = nn.Sigmoid()
            self.W = nn.Linear(d_model, d_model, bias=False)
        else:
            raise ValueError(f"Unknown arch_style: {arch_style}")
        self.arch_style = arch_style

    def forward(
        self,
        cell_emb: Tensor,
        gene_embs: Tensor,
    ) -> Dict[str, Tensor]:
        """Score genes against a cell embedding.

        Args:
            cell_emb: Cell embedding, (batch, d_model).
            gene_embs: Gene embeddings, (batch, n_genes, d_model).

        Returns:
            ``{"pred": (batch, n_genes)}`` inner-product scores, optionally
            scaled by ``1/sqrt(d_model)``.
        """
        if self.arch_style == "inner product":
            query_vecs = self.query_activation(self.gene2query(gene_embs))
            inner_product_dimension = query_vecs.shape[-1]
            # (batch, d_model) -> (batch, d_model, 1) for the batched matmul.
            cell_emb = cell_emb.unsqueeze(2)
            pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)
            if self.scaled_dot_product:
                pred_value = pred_value / torch.sqrt(
                    torch.tensor(inner_product_dimension, dtype=pred_value.dtype)
                )
            return {"pred": pred_value}
        else:
            raise ValueError(f"Unknown arch_style: {self.arch_style}")