# tx-model-standalone / blocks_standalone.py
# Yuto2007's picture
# Upload folder using huggingface_hub
# e093a4b verified
# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Standalone implementation of TXModel blocks without external dependencies.
Only requires: torch, transformers
"""
import math
from typing import Optional, Dict, Any, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with optional grouped-query key/value heads.

    Queries use ``n_heads`` heads. Keys and values use ``kv_n_heads`` heads,
    each repeated ``n_heads // kv_n_heads`` times so they line up with the
    query heads.

    Args:
        d_model: Width of the input and output embeddings.
        n_heads: Number of query heads; must divide ``d_model``.
        kv_n_heads: Number of key/value heads; defaults to ``n_heads`` and
            must divide ``n_heads``.
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the four linear projections have a bias term.
        device: Device on which the projection parameters are created.

    Raises:
        ValueError: If a head count is not positive, ``d_model`` is not
            divisible by ``n_heads``, or ``n_heads`` is not divisible by
            ``kv_n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,
        device: Optional[str] = None,
    ):
        super().__init__()
        if kv_n_heads is None:
            kv_n_heads = n_heads

        # Fail at construction time: otherwise the floor divisions below
        # silently produce sizes that only blow up inside ``forward``.
        if n_heads <= 0 or kv_n_heads <= 0:
            raise ValueError(
                f"n_heads ({n_heads}) and kv_n_heads ({kv_n_heads}) must be positive"
            )
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if n_heads % kv_n_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by kv_n_heads ({kv_n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads
        self.head_dim = d_model // n_heads
        self.dropout = dropout

        # Grouped-query attention: how many query heads share one k/v head.
        self.n_rep = n_heads // kv_n_heads

        kv_dim = kv_n_heads * self.head_dim
        self.q_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
        self.k_proj = nn.Linear(d_model, kv_dim, bias=bias, device=device)
        self.v_proj = nn.Linear(d_model, kv_dim, bias=bias, device=device)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias, device=device)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False,
        **kwargs
    ) -> Tuple[Tensor, None, None]:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            attn_bias: Optional tensor added to the scaled scores; it must
                broadcast against ``(batch, n_heads, seq_len, seq_len)``.
            key_padding_mask: Optional ``(batch, seq_len)`` mask that is
                truthy for valid positions. Boolean and 0/1 numeric masks
                are both accepted.
            is_causal: If true, each position attends only to itself and
                earlier positions.
            **kwargs: Accepted and ignored.

        Returns:
            ``(output, None, None)`` where ``output`` has the shape of ``x``.

        Note:
            A query whose keys are all masked produces NaN, because the
            softmax of a row of ``-inf`` is undefined.
        """
        batch_size, seq_len, _ = x.shape

        # Project, split into heads, and move heads ahead of the sequence
        # axis: (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Expand k/v heads so every query head has a matching k/v head.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attn_bias is not None:
            attn_scores = attn_scores + attn_bias

        if key_padding_mask is not None:
            # Coerce to bool first: ``~`` on an integer mask is a bitwise NOT
            # and ``masked_fill`` only accepts boolean masks.
            valid = key_padding_mask.to(torch.bool).unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~valid, float("-inf"))

        if is_causal:
            # True strictly above the diagonal, i.e. at future positions.
            future = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
            attn_scores = attn_scores.masked_fill(future, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        output = torch.matmul(attn_weights, v)

        # Merge heads back into the model dimension and project.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        return output, None, None
class TXBlock(nn.Module):
    """Transformer encoder block with pre- or post-normalization.

    The block is self-attention followed by a feed-forward network, each
    wrapped in a residual connection.

    Args:
        d_model: Width of the token embeddings.
        n_heads: Number of attention heads.
        expansion_ratio: Feed-forward hidden width as a multiple of ``d_model``.
        attn_config: Optional attention settings; the keys read are
            ``"kv_n_heads"`` (default ``n_heads``) and ``"attn_pdrop"``
            (default ``0.0``).
        norm_config: Optional normalization settings; the key read is
            ``"eps"`` (default ``1e-5``).
        dropout: Dropout probability after attention and after the
            feed-forward network. ``None`` is treated as ``0.0``.
        activation: One of ``"gelu"``, ``"relu"``, ``"silu"``, ``"swish"``
            or ``"leaky_relu"``.
        device: Device on which parameters are created.
        norm_scheme: ``"pre"`` or ``"post"``.
        use_glu: If true, the feed-forward network uses a gated linear unit.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``norm_scheme`` or ``activation`` is not recognised.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: Optional[float] = 0.0,
        activation: Optional[str] = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # Validate before allocating any parameters so a bad config fails fast.
        if norm_scheme not in ["pre", "post"]:
            raise ValueError("norm_scheme must be either pre or post")

        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}
        # The signature allows None, but nn.Dropout needs a number.
        if dropout is None:
            dropout = 0.0

        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu

        # Attention
        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            dropout=attn_config.get("attn_pdrop", 0.0),
            device=device,
        )

        # Feed-forward network
        dim_feedforward = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)
        if use_glu:
            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)

        # Normalization
        eps = norm_config.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)

        # Dropout
        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)

        # Activation
        self.activation = self._get_activation_fn(activation)

    @staticmethod
    def _get_activation_fn(activation: str):
        """Return a new activation module for ``activation``.

        Raises:
            ValueError: If the name is not a supported activation.
        """
        if activation == "gelu":
            return nn.GELU()
        if activation == "relu":
            return nn.ReLU()
        if activation in ("silu", "swish"):
            return nn.SiLU()
        if activation == "leaky_relu":
            return nn.LeakyReLU()
        raise ValueError(f"Unknown activation: {activation}")

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tensor:
        """Run the block on ``x`` and return a tensor of the same shape.

        Args:
            x: Input embeddings.
            attn_bias: Optional bias forwarded to the attention layer.
            key_padding_mask: Optional padding mask forwarded to the
                attention layer.
            **kwargs: Accepted and ignored.
        """
        if self.norm_scheme == "pre":
            # Pre-norm: normalize, apply the sub-layer, then add the residual.
            x = x + self._sa_block(
                self.norm1(x),
                attn_bias=attn_bias,
                key_padding_mask=key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm: apply the sub-layer, add the residual, then normalize.
            x = self.norm1(
                x + self._sa_block(
                    x,
                    attn_bias=attn_bias,
                    key_padding_mask=key_padding_mask,
                )
            )
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Non-causal self-attention followed by dropout."""
        x, _, _ = self.self_attn(
            x,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=False,
        )
        return self.post_sa_dropout(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward network (optionally gated) followed by dropout."""
        if self.use_glu:
            # GLU variant: down(activation(gate(x)) * up(x))
            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
        else:
            # Standard FFN: down(activation(up(x)))
            x = self.down_proj(self.activation(self.up_proj(x)))
        return self.post_ffn_dropout(x)
class TXEncoder(nn.Module):
    """Stack of ``num_layers`` transformer blocks modelled on a template block.

    Each layer is a freshly initialised ``TXBlock`` that mirrors the
    template's hyper-parameters: width, head count, expansion ratio,
    activation, dropout, device, normalization scheme and GLU setting.
    The template's weights are not copied.

    Args:
        encoder_layer: Template block whose hyper-parameters are reused.
        num_layers: Number of blocks in the stack.
        use_norm: If true, apply a final LayerNorm to the output.
        norm_config: Optional normalization settings (``"eps"``), used for
            every layer and for the final norm.
        attn_config: Optional attention settings for every layer. When
            omitted, the template's ``kv_n_heads`` and attention dropout
            are reused.
    """

    # Activation modules that ``TXBlock`` can build, keyed by the name it accepts.
    _ACTIVATION_NAMES = (
        ("gelu", nn.GELU),
        ("relu", nn.ReLU),
        ("silu", nn.SiLU),
        ("leaky_relu", nn.LeakyReLU),
    )

    def __init__(
        self,
        encoder_layer: TXBlock,
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        attn_config: Optional[Dict] = None,
    ):
        super().__init__()
        if norm_config is None:
            norm_config = {}
        if attn_config is None:
            # Without an explicit config, mirror the template's attention
            # settings rather than falling back to TXBlock's defaults.
            attn_config = {
                "kv_n_heads": encoder_layer.self_attn.kv_n_heads,
                "attn_pdrop": encoder_layer.self_attn.dropout,
            }

        d_model = encoder_layer.d_model
        expansion_ratio = encoder_layer.up_proj.out_features // d_model
        activation = self._activation_name(encoder_layer.activation)
        dropout = encoder_layer.post_sa_dropout.p

        # Rebuild (rather than share) the layer so each one has its own weights.
        self.layers = nn.ModuleList([
            TXBlock(
                d_model=d_model,
                n_heads=encoder_layer.n_heads,
                expansion_ratio=expansion_ratio,
                attn_config=attn_config,
                norm_config=norm_config,
                dropout=dropout,
                activation=activation,
                device=encoder_layer.device,
                norm_scheme=encoder_layer.norm_scheme,
                use_glu=encoder_layer.use_glu,
            )
            for _ in range(num_layers)
        ])

        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            # Create the final norm on the same device as the layers.
            self.norm = nn.LayerNorm(d_model, eps=eps, device=encoder_layer.device)

    @classmethod
    def _activation_name(cls, module: nn.Module) -> str:
        """Map an activation module back to the name ``TXBlock`` accepts.

        Falls back to ``"gelu"`` for module types ``TXBlock`` cannot build.
        """
        for name, module_type in cls._ACTIVATION_NAMES:
            if isinstance(module, module_type):
                return name
        return "gelu"

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        """Run every layer in order.

        Args:
            total_embs: Input embeddings passed to the first layer.
            key_padding_mask: Optional padding mask forwarded to every layer.
            output_hidden_states: If true, also collect each layer's output.

        Returns:
            ``(x, hidden_states)``. ``x`` is the last layer's output, after
            the final norm when ``use_norm`` is set. ``hidden_states`` is the
            list of per-layer outputs (taken before the final norm), or
            ``None`` when not requested.
        """
        x = total_embs
        hidden_states = [] if output_hidden_states else None

        for layer in self.layers:
            x = layer(
                x,
                attn_bias=None,
                key_padding_mask=key_padding_mask,
            )
            if output_hidden_states:
                hidden_states.append(x)

        if self.use_norm:
            x = self.norm(x)

        return x, hidden_states
class GeneEncoder(nn.Module):
    """Look up gene-token embeddings, optionally followed by LayerNorm.

    Args:
        num_embeddings: Size of the gene vocabulary.
        embedding_dim: Width of each embedding vector.
        padding_idx: Index passed to ``nn.Embedding`` as the padding entry.
        use_norm: If true, normalize the embeddings with LayerNorm.
        gene_encoder_cfg: Accepted for interface compatibility; the
            standalone version has no extra embeddings and does not read it.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        gene_encoder_cfg: Optional[Dict] = None,
    ):
        super().__init__()
        self.use_norm = use_norm
        self.embedding = nn.Embedding(
            num_embeddings, embedding_dim, padding_idx=padding_idx
        )
        # No extra embeddings in the standalone version, so the projection
        # is a pass-through.
        self.project = nn.Identity()
        if use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Embed the token ids in ``x``; adds a trailing embedding axis."""
        embedded = self.project(self.embedding(x))
        if not self.use_norm:
            return embedded
        return self.enc_norm(embedded)
class ChemEncoder(nn.Module):
    """Encode drug indices through a fingerprint table and a two-layer MLP.

    The fingerprint table is created as all zeros; real fingerprints are
    expected to be loaded afterwards (e.g. from a checkpoint).

    Args:
        d_out: Width of the output embedding.
        padding_idx: Index passed to the embedding table as the padding entry.
        activation: ``"leaky_relu"``, ``"relu"`` or ``"gelu"``; any other
            value falls back to ReLU.
        use_norm: If true, apply LayerNorm to the output.
        freeze: If true, the fingerprint table is not trainable.
        num_drugs: Number of rows in the fingerprint table.
        fp_dim: Width of each fingerprint.
    """

    def __init__(
        self,
        d_out: int,
        padding_idx: int = 0,
        activation: str = "leaky_relu",
        use_norm: bool = True,
        freeze: bool = False,
        num_drugs: int = 1000,
        fp_dim: int = 2048,
    ):
        super().__init__()

        # Zero placeholder table; pretrained fingerprints replace it later.
        placeholder_fps = torch.zeros((num_drugs, fp_dim), dtype=torch.float32)
        self.embedding = nn.Embedding.from_pretrained(
            placeholder_fps, padding_idx=padding_idx, freeze=freeze
        )

        self.fc = nn.Linear(fp_dim, d_out)

        # Unknown names deliberately resolve to ReLU.
        known_activations = {
            "leaky_relu": nn.LeakyReLU,
            "relu": nn.ReLU,
            "gelu": nn.GELU,
        }
        self.activation = known_activations.get(activation, nn.ReLU)()

        self.proj = nn.Linear(d_out, d_out)

        self.use_norm = use_norm
        if use_norm:
            self.norm = nn.LayerNorm(d_out)

    def forward(self, x: Tensor) -> Tensor:
        """Map drug indices ``x`` to embeddings of width ``d_out``."""
        fingerprints = self.embedding(x)
        hidden = self.activation(self.fc(fingerprints))
        encoded = self.proj(hidden)
        return self.norm(encoded) if self.use_norm else encoded
class ContinuousValueEncoder(nn.Module):
    """Encode scalar values into ``d_model``-wide embeddings with a small MLP.

    Args:
        d_model: Width of the output embedding.
        dropout: Dropout probability applied to the output.
        max_value: Upper bound; larger inputs are clamped to it.
        activation: ``"relu"``, ``"gelu"`` or ``"leaky_relu"``; any other
            value falls back to ReLU.
        use_norm: If true, apply LayerNorm before dropout.
    """

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)

        # Unknown names deliberately resolve to ReLU.
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "leaky_relu":
            self.activation = nn.LeakyReLU()
        else:
            self.activation = nn.ReLU()

        self.linear2 = nn.Linear(d_model, d_model)

        self.use_norm = use_norm
        if self.use_norm:
            self.norm = nn.LayerNorm(d_model)

        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        """Embed each scalar in ``x``.

        Args:
            x: Tensor of values with any shape. Integer tensors and floats of
                another precision are accepted and cast to the parameter dtype.

        Returns:
            Tensor with the shape of ``x`` plus a trailing ``d_model`` axis.
        """
        # Add a feature axis of size 1 for the first linear layer, and cast
        # to the parameter dtype: nn.Linear rejects integer inputs and floats
        # whose precision differs from its weights.
        x = x.unsqueeze(-1).to(self.linear1.weight.dtype)

        # Clip to max value (only an upper bound is applied).
        x = torch.clamp(x, max=self.max_value)

        x = self.activation(self.linear1(x))
        x = self.linear2(x)

        if self.use_norm:
            x = self.norm(x)

        return self.dropout(x)
class ExprDecoder(nn.Module):
    """Decode per-token embeddings into expression predictions with an MLP.

    Args:
        d_model: Width of the input embeddings and of every hidden layer.
        n_outputs: Number of values predicted per token.
        n_layers: Number of hidden ``d_model -> d_model`` layers.
        activation: ``"leaky_relu"``, ``"relu"`` or ``"gelu"``; any other
            value falls back to LeakyReLU.
    """

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()

        # Unknown names deliberately resolve to LeakyReLU.
        if activation == "relu":
            chosen = nn.ReLU()
        elif activation == "gelu":
            chosen = nn.GELU()
        else:
            chosen = nn.LeakyReLU()
        self.activation = chosen

        hidden = []
        for _ in range(n_layers):
            hidden.append(nn.Linear(d_model, d_model))
        self.linear_layers = nn.ModuleList(hidden)

        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Return ``{"pred": values}`` for the embeddings in ``x``.

        A trailing axis of size 1 is squeezed away, so single-output
        predictions have one axis fewer than ``x``.
        """
        hidden = x
        for linear in self.linear_layers:
            hidden = self.activation(linear(hidden))

        prediction = self.out_proj(hidden)
        if prediction.shape[-1] == 1:
            prediction = prediction.squeeze(-1)
        return {"pred": prediction}
class MVCDecoder(nn.Module):
    """Predict per-gene values from a cell embedding via an inner product.

    Each gene embedding is turned into a query vector, passed through ``W``,
    and dotted with the cell embedding.

    Args:
        d_model: Width of the gene and cell embeddings.
        arch_style: Only ``"inner product"`` is supported.
        query_activation: ``"sigmoid"``, ``"relu"`` or ``"tanh"``; any other
            value falls back to sigmoid.
        scaled_dot_product: If true, divide predictions by ``sqrt(d_model)``.

    Raises:
        ValueError: If ``arch_style`` is not ``"inner product"``.
    """

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ) -> None:
        super().__init__()
        self.scaled_dot_product = scaled_dot_product

        if arch_style != "inner product":
            raise ValueError(f"Unknown arch_style: {arch_style}")

        self.gene2query = nn.Linear(d_model, d_model)

        # Unknown names deliberately resolve to sigmoid.
        if query_activation == "relu":
            self.query_activation = nn.ReLU()
        elif query_activation == "tanh":
            self.query_activation = nn.Tanh()
        else:
            self.query_activation = nn.Sigmoid()

        self.W = nn.Linear(d_model, d_model, bias=False)
        self.arch_style = arch_style

    def forward(
        self,
        cell_emb: Tensor,
        gene_embs: Tensor,
    ) -> Dict[str, Tensor]:
        """Return ``{"pred": values}`` with one value per gene.

        Args:
            cell_emb: Cell embeddings of shape ``(batch, d_model)``.
            gene_embs: Gene embeddings of shape ``(batch, n_genes, d_model)``.

        Raises:
            ValueError: If ``arch_style`` was changed to an unsupported value.
        """
        if self.arch_style != "inner product":
            raise ValueError(f"Unknown arch_style: {self.arch_style}")

        queries = self.query_activation(self.gene2query(gene_embs))
        width = queries.shape[-1]

        # (batch, n_genes, d) @ (batch, d, 1) -> (batch, n_genes, 1)
        cell_column = cell_emb.unsqueeze(2)
        scores = torch.bmm(self.W(queries), cell_column).squeeze(2)

        if self.scaled_dot_product:
            denominator = torch.sqrt(torch.tensor(width, dtype=scores.dtype))
            scores = scores / denominator

        return {"pred": scores}