# ------------------------------------------------------------
# CancerTranscriptome-Mini-48M
# Model: Lightweight adaptation of BulkFormer
# Author: Walter Alvarado (NASA Ames Research Center)
# License: MIT
#
# References:
# (1) Boming Kang, Rui Fan, Meizheng Yi, Chunmei Cui, Qinghua Cui.
#     "A large-scale foundation model for bulk transcriptomes."
#     bioRxiv (2025). doi:10.1101/2025.06.11.659222
#
# (2) Alvarado W. "CancerTranscriptome-Mini-48M: A compact cancer-
#     focused BulkFormer derivative." https://github.com/alwalt/BioFM
#
# Data Source:
# ARCHS4 Human RNA-seq v2.5 (Lachmann et al., Nat Commun 2018)
# ------------------------------------------------------------
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from performer_pytorch import Performer

# Default model hyperparameters
model_params = {
    "dim": 320,           # hidden width
    "bins": 10,           # importance bins per GBFormer block
    "gb_repeat": 1,       # number of stacked GBFormer blocks
    "p_repeat": 2,        # global Performer layers per block
    "bin_head": 8,        # attention heads in the per-bin Performers
    "full_head": 4,       # attention heads in the global Performers
    "gene_length": 19357  # genes per expression profile
}


# ------------------------------------------------------------
# Rotary Expression Embedding (REE)
# ------------------------------------------------------------
class PositionalExprEmbedding(nn.Module):
    """
    Rotary Expression Embedding (REE): converts continuous gene
    expression values into a sinusoidal embedding usable by
    Performer/Transformer blocks. Deterministic, not learned.
    Masked positions (value == -10) map to the zero vector.
    """

    def __init__(self, dim, mask_token=-10):
        super().__init__()
        self.mask_token = mask_token
        # Fixed frequency table; registered as a buffer so it moves
        # with the module across devices but is never trained.
        self.register_buffer(
            "inv_freq",
            1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim))
        )

    def forward(self, x):
        # x: [B, G] continuous expression values
        mask = (x == self.mask_token).nonzero(as_tuple=False)
        x = torch.einsum("bi,j->bij", x, self.inv_freq)  # [B, G, dim/2]
        x = torch.cat([x.sin(), x.cos()], dim=-1)        # [B, G, dim]
        x[mask[:, 0], mask[:, 1]] = 0                    # zero out masked genes
        return x


# ------------------------------------------------------------
# GBFormer Block (Graph + Local Performer + Global Performer)
# ------------------------------------------------------------
class GBFormer(nn.Module):
    """
    A single GBFormer block:
      - LayerNorm
      - GCNConv (gene-gene propagation)
      - Binning by learned importance score
      - Local Performer per bin
      - Global Performer
    """

    def __init__(self, dim, gene_length, bin_head, full_head, bins, p_repeat):
        super().__init__()
        self.dim = dim
        self.bins = bins
        self.bin_head = bin_head
        self.full_head = full_head
        self.p_repeat = p_repeat

        self.layernorm = nn.LayerNorm(dim)
        self.gcn = GCNConv(dim, dim, cached=True, add_self_loops=False)

        # Learned scoring -> assign each gene to a bin
        self.which_bin = nn.Linear(dim, 1)

        # Local Performer per bin
        self.bin_layers = nn.ModuleList([
            Performer(
                dim=dim,
                heads=bin_head,
                depth=1,
                dim_head=dim // bin_head,
                attn_dropout=0.2,
                ff_dropout=0.2
            )
            for _ in range(bins)
        ])

        # Global Performer stack
        self.global_layers = nn.Sequential(*[
            Performer(
                dim=dim,
                heads=full_head,
                depth=1,
                dim_head=dim // full_head
            )
            for _ in range(p_repeat)
        ])

    def forward(self, x, graph):
        B, G, D = x.shape
        x = self.layernorm(x)
        x = x + self.gcn(x, graph)  # residual GCN update

        if self.bins > 0:
            scores = self.which_bin(x).squeeze(-1)                 # [B, G]
            order = torch.argsort(scores, dim=1, descending=True)  # rank genes by score
            order_full = order.unsqueeze(-1).expand(-1, -1, D)
            x_sorted = x.gather(1, order_full)

            bin_size = (G - 1) // self.bins + 1                    # ceil(G / bins)
            chunks = torch.split(x_sorted, bin_size, dim=1)
            processed = [
                layer(chunk) for chunk, layer in zip(chunks, self.bin_layers)
            ]
            x_cat = torch.cat(processed, dim=1)
            # Inverse permutation: scatter bin outputs back to the
            # original gene order.
            x = torch.empty_like(x_cat).scatter_(1, order_full, x_cat)

        x = self.global_layers(x)
        return x
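
# ------------------------------------------------------------
# Usage sketch for a single GBFormer block. Illustrative only:
# the toy gene count, batch size, and ring graph below are
# assumptions for demonstration, not the real ARCHS4 setup.
#
#   block = GBFormer(dim=320, gene_length=8, bin_head=8,
#                    full_head=4, bins=2, p_repeat=1)
#   edge_index = torch.stack([torch.arange(8), (torch.arange(8) + 1) % 8])
#   h = block(torch.randn(2, 8, 320), edge_index)  # -> [2, 8, 320]
# ------------------------------------------------------------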
# ------------------------------------------------------------
# Full BulkFormer Model
# ------------------------------------------------------------
class BulkFormer(nn.Module):
    """
    CancerTranscriptome-Mini-48M: a compact BulkFormer-style
    masked-expression model. Combines:
      - ESM2 gene identity embeddings
      - Rotary Expression Embeddings (REE)
      - Graph convolution (GCNConv)
      - Local/global Performer attention
      - Optional intermediate repr_layers for feature extraction
    """

    def __init__(
        self,
        dim,
        graph,
        gene_emb,
        gene_length,
        bin_head=4,
        full_head=4,
        bins=10,
        gb_repeat=1,
        p_repeat=1
    ):
        super().__init__()
        self.dim = dim
        self.graph = graph
        self.gene_length = gene_length

        # Identity embeddings from ESM2 (trainable, with a learned projection)
        self.gene_emb = nn.Parameter(gene_emb)
        self.gene_proj = nn.Sequential(
            nn.Linear(gene_emb.shape[1], 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        # REE for expression values
        self.expr_emb = PositionalExprEmbedding(dim)

        # Pre-attention mixing layer
        self.mix = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        # Stacked GBFormer blocks
        self.gb_blocks = nn.ModuleList([
            GBFormer(dim, gene_length, bin_head, full_head, bins, p_repeat)
            for _ in range(gb_repeat)
        ])

        self.final_norm = nn.LayerNorm(dim)

        # Output head -> scalar prediction per gene
        self.head = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, 1),
            nn.ReLU()
        )

    def forward(self, x, repr_layers=None):
        # x: [B, G] expression profiles; masked entries hold the mask token.
        B, G = x.shape
        hidden = {}

        x = (
            self.expr_emb(x)
            + self.gene_proj(self.gene_emb)
            + torch.zeros(B, 1, self.dim, device=x.device)  # no AE latent in this version
        )
        x = self.mix(x)

        for i, block in enumerate(self.gb_blocks):
            x = block(x, self.graph)
            if repr_layers and i in repr_layers:
                hidden[i] = x  # cache intermediate representation

        x = self.final_norm(x)
        out = self.head(x).squeeze(-1)  # [B, G] per-gene predictions

        if repr_layers:
            return out, hidden
        return out
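
# ------------------------------------------------------------
# Smoke test (sketch): builds a tiny model with random stand-ins
# for the ESM2 gene embeddings and the gene-gene graph. The real
# pipeline loads both from the released assets; every shape and
# value below is an illustrative assumption.
# ------------------------------------------------------------
if __name__ == "__main__":
    G = 64                            # toy gene count, not the full 19357
    gene_emb = torch.randn(G, 1280)   # stand-in for ESM2 gene embeddings
    edge_index = torch.stack([        # toy ring graph over genes
        torch.arange(G),
        (torch.arange(G) + 1) % G
    ])

    model = BulkFormer(
        dim=model_params["dim"],
        graph=edge_index,
        gene_emb=gene_emb,
        gene_length=G,
        bin_head=model_params["bin_head"],
        full_head=model_params["full_head"],
        bins=model_params["bins"],
        gb_repeat=model_params["gb_repeat"],
        p_repeat=model_params["p_repeat"],
    )

    expr = torch.rand(2, G)           # batch of 2 expression profiles
    expr[:, :5] = -10                 # mask the first 5 genes (mask token)
    out = model(expr)
    print(out.shape)                  # expected: torch.Size([2, 64])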