# ------------------------------------------------------------
# CancerTranscriptome-Mini-48M
# Model: Lightweight adaptation of BulkFormer
# Author: Walter Alvarado (NASA Ames Research Center)
# License: MIT
#
# References:
# (1) Boming Kang, Rui Fan, Meizheng Yi, Chunmei Cui, Qinghua Cui.
# “A large-scale foundation model for bulk transcriptomes.”
# bioRxiv (2025). doi:10.1101/2025.06.11.659222
#
# (2) Alvarado W. “CancerTranscriptome-Mini-48M: A compact cancer-
# focused BulkFormer derivative.” https://github.com/alwalt/BioFM
#
# Data Source:
# ARCHS4 Human RNA-seq v2.5 (Lachmann et al., Nat Commun 2018)
# ------------------------------------------------------------
import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from performer_pytorch import Performer
# Default model hyperparameters
model_params = {
"dim": 320,
"bins": 10,
"gb_repeat": 1,
"p_repeat": 2,
"bin_head": 8,
"full_head": 4,
"gene_length": 19357
}
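# Minimal usage sketch (an assumption for illustration, not taken from the upstream
# repo): the keys above mirror the BulkFormer constructor arguments defined below;
# edge_index (the gene-gene graph) and esm2_embeddings (the ESM2 identity matrix)
# are hypothetical names for inputs that must be loaded separately, e.g.:
#
#   model = BulkFormer(
#       dim=model_params["dim"],
#       graph=edge_index,                         # [2, num_edges] LongTensor
#       gene_emb=esm2_embeddings,                 # [gene_length, esm_dim] FloatTensor
#       gene_length=model_params["gene_length"],
#       bin_head=model_params["bin_head"],
#       full_head=model_params["full_head"],
#       bins=model_params["bins"],
#       gb_repeat=model_params["gb_repeat"],
#       p_repeat=model_params["p_repeat"],
#   )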
# ------------------------------------------------------------
# Rotary Expression Embedding (REE)
# ------------------------------------------------------------
class PositionalExprEmbedding(nn.Module):
"""
Rotary Expression Embedding (REE):
Converts continuous gene expression values into a sinusoidal
embedding usable by Performer/Transformer blocks. Deterministic,
not learned. Masked positions (-10) → zero vector.
"""
def __init__(self, dim, mask_token=-10):
super().__init__()
self.mask_token = mask_token
self.inv_freq = nn.Parameter(
1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim)),
requires_grad=False
)
def forward(self, x):
        # Record masked entries before embedding so they can be zeroed afterwards.
        mask = (x == self.mask_token).nonzero(as_tuple=False)
        # [B, G] expression values -> [B, G, dim/2] phase angles, then sin/cos
        # concatenation -> [B, G, dim].
        x = torch.einsum("bi,j->bij", x, self.inv_freq)
        x = torch.cat([x.sin(), x.cos()], dim=-1)
        # Zero out the embedding of every masked position.
        x[mask[:, 0], mask[:, 1]] = 0
        return x
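# Shape sketch for PositionalExprEmbedding (illustrative, assuming log-normalized
# expression input): [batch, genes] -> [batch, genes, dim], with the first dim/2
# channels holding sin(expr * inv_freq) and the remaining channels cos(...);
# every entry equal to mask_token (-10) yields an all-zero embedding, e.g.:
#
#   ree = PositionalExprEmbedding(dim=320)
#   expr = torch.full((2, 19357), -10.0)   # fully masked toy batch
#   emb = ree(expr)                        # -> [2, 19357, 320], all zeros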
# ------------------------------------------------------------
# GBFormer Block (Graph + Local Performer + Global Performer)
# ------------------------------------------------------------
class GBFormer(nn.Module):
"""
A single GBFormer block:
- LayerNorm
- GCNConv (gene-gene propagation)
- Binning by learned importance score
- Local Performer per-bin
- Global Performer
"""
def __init__(self, dim, gene_length, bin_head, full_head, bins, p_repeat):
super().__init__()
self.dim = dim
self.bins = bins
self.bin_head = bin_head
self.full_head = full_head
self.p_repeat = p_repeat
self.layernorm = nn.LayerNorm(dim)
self.gcn = GCNConv(dim, dim, cached=True, add_self_loops=False)
        # Learned importance score: one scalar per gene, used to sort genes into bins
self.which_bin = nn.Linear(dim, 1)
# Local Performer per bin
self.bin_layers = nn.ModuleList([
Performer(
dim=dim,
heads=bin_head,
depth=1,
dim_head=dim // bin_head,
attn_dropout=0.2,
ff_dropout=0.2
)
for _ in range(bins)
])
# Global Performer stack
self.global_layers = nn.Sequential(*[
Performer(
dim=dim,
heads=full_head,
depth=1,
dim_head=dim // full_head
)
for _ in range(p_repeat)
])
def forward(self, x, graph):
B, G, D = x.shape
x = self.layernorm(x)
x = x + self.gcn(x, graph) # residual GCN update
        if self.bins > 0:
            # Rank genes by learned importance and sort each sample accordingly.
            scores = self.which_bin(x).squeeze(-1)  # [B, G]
            order = torch.argsort(scores, dim=1, descending=True)
            order_full = order.unsqueeze(-1).expand(-1, -1, D)
            x_sorted = x.gather(1, order_full)
            # Split the sorted genes into `bins` chunks of ceil(G / bins) genes
            # and run each chunk through its own local Performer.
            bin_size = (G - 1) // self.bins + 1
            chunks = torch.split(x_sorted, bin_size, dim=1)
            processed = [
                layer(chunk)
                for chunk, layer in zip(chunks, self.bin_layers)
            ]
            x_cat = torch.cat(processed, dim=1)
            # Restore the original gene order before the global Performer stack.
            x = torch.empty_like(x_cat).scatter_(1, order_full, x_cat)
x = self.global_layers(x)
return x
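# Binning sketch (illustrative): with gene_length=19357 and bins=10, genes are
# ranked per sample by the learned which_bin score and split into chunks of
# ceil(19357 / 10) = 1936 genes (the final chunk holds the remaining 1933).
# Each chunk runs through its own local Performer, after which scatter_ restores
# the canonical gene order before the global Performer stack sees the sequence.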
# ------------------------------------------------------------
# Full BulkFormer Model
# ------------------------------------------------------------
class BulkFormer(nn.Module):
"""
CancerTranscriptome-Mini-48M:
A compact BulkFormer-style masked-expression model.
Combines:
- ESM2 gene identity embeddings
- Rotary Expression Embeddings (REE)
- Graph Convolution (GCNConv)
- Local/global Performer attention
- Optional intermediate repr_layers for feature extraction
"""
def __init__(
self,
dim,
graph,
gene_emb,
gene_length,
bin_head=4,
full_head=4,
bins=10,
gb_repeat=1,
p_repeat=1
):
super().__init__()
self.dim = dim
self.graph = graph
self.gene_length = gene_length
        # ESM2-derived gene identity embeddings (fine-tuned as parameters) with a learned projection
self.gene_emb = nn.Parameter(gene_emb)
self.gene_proj = nn.Sequential(
nn.Linear(gene_emb.shape[1], 4 * dim),
nn.ReLU(),
nn.Linear(4 * dim, dim)
)
# REE for expression
self.expr_emb = PositionalExprEmbedding(dim)
# Pre-attention mixing layer
self.mix = nn.Sequential(
nn.Linear(dim, 4 * dim),
nn.ReLU(),
nn.Linear(4 * dim, dim)
)
# Stacked GBFormer blocks
self.gb_blocks = nn.ModuleList([
GBFormer(dim, gene_length, bin_head, full_head, bins, p_repeat)
for _ in range(gb_repeat)
])
self.final_norm = nn.LayerNorm(dim)
# Output head → scalar prediction per gene
self.head = nn.Sequential(
nn.Linear(dim, 4 * dim),
nn.ReLU(),
nn.Linear(4 * dim, 1),
nn.ReLU()
)
def forward(self, x, repr_layers=None):
B, G = x.shape
hidden = {}
x = (
self.expr_emb(x) +
self.gene_proj(self.gene_emb) +
torch.zeros(B, 1, self.dim, device=x.device) # no AE latent in this version
)
x = self.mix(x)
for i, block in enumerate(self.gb_blocks):
x = block(x, self.graph)
if repr_layers and i in repr_layers:
hidden[i] = x
x = self.final_norm(x)
out = self.head(x).squeeze(-1)
if repr_layers:
return out, hidden
return out
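# ------------------------------------------------------------
# Smoke test (illustrative sketch, not part of the published training code).
# The random edge_index and gene_emb below are stand-ins for the real gene-gene
# graph and ESM2 embedding matrix distributed with the checkpoint.
# ------------------------------------------------------------
if __name__ == "__main__":
    G, esm_dim = 64, 128                          # tiny toy sizes, not the full 19357 genes
    edge_index = torch.randint(0, G, (2, 256))    # random gene-gene edges
    esm2_embeddings = torch.randn(G, esm_dim)     # random stand-in for ESM2 identity embeddings
    model = BulkFormer(
        dim=model_params["dim"],
        graph=edge_index,
        gene_emb=esm2_embeddings,
        gene_length=G,
        bin_head=model_params["bin_head"],
        full_head=model_params["full_head"],
        bins=model_params["bins"],
        gb_repeat=model_params["gb_repeat"],
        p_repeat=model_params["p_repeat"],
    )
    expr = torch.randn(2, G)                      # log-expression values for 2 samples
    expr[:, :8] = -10                             # mask the first 8 genes
    out, hidden = model(expr, repr_layers=[0])
    print(out.shape, hidden[0].shape)             # expected: [2, 64] and [2, 64, 320]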