"""
TXModel - Complete Standalone Implementation for HuggingFace

All code in one file - requires ONLY: transformers, torch, safetensors
"""

import copy
import math
from typing import Optional, Dict, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput


class TXConfig(PretrainedConfig):
    """Configuration for TXModel"""

    model_type = "tx_model"

    def __init__(
        self,
        vocab_size: int = 30000,
        d_model: int = 512,
        n_layers: int = 12,
        n_heads: int = 8,
        expansion_ratio: int = 4,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        pad_token_id: int = 0,
        pad_value: float = 0.0,
        num_bins: int = 51,
        use_chem_token: bool = False,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        gene_encoder_config: Optional[Dict] = None,
        expression_encoder_config: Optional[Dict] = None,
        expression_decoder_config: Optional[Dict] = None,
        mvc_config: Optional[Dict] = None,
        chemical_encoder_config: Optional[Dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        standard_scale_outputs: bool = False,
        keep_first_n_tokens: int = 1,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.cell_emb_style = cell_emb_style
        self.pad_value = pad_value
        self.num_bins = num_bins
        self.use_chem_token = use_chem_token
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings
        self.standard_scale_outputs = standard_scale_outputs
        self.use_glu = use_glu

        # Sub-module configs default to empty dicts; mvc_config and
        # chemical_encoder_config remain None when the corresponding
        # module is disabled.
        self.attn_config = attn_config or {}
        self.norm_config = norm_config or {}
        self.gene_encoder_config = gene_encoder_config or {}
        self.expression_encoder_config = expression_encoder_config or {}
        self.expression_decoder_config = expression_decoder_config or {}
        self.mvc_config = mvc_config
        self.chemical_encoder_config = chemical_encoder_config


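# Illustrative usage (hypothetical sizes, not tied to any released checkpoint):
#
#   config = TXConfig(vocab_size=1000, d_model=64, n_layers=2, n_heads=4)
#   model = TXModelForHF(config)
#
# See the __main__ smoke test at the bottom of this file for a runnable version.

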
class MultiheadAttention(nn.Module):
    """Multi-head attention with grouped query support"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads
        self.head_dim = d_model // n_heads
        self.dropout = dropout
        # Number of query heads served by each key/value head (GQA factor).
        self.n_rep = n_heads // self.kv_n_heads

        self.q_proj = nn.Linear(d_model, d_model, device=device)
        self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
        self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, device=device)
        self.out_proj = nn.Linear(d_model, d_model, device=device)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, None, None]:
        batch_size, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim).transpose(1, 2)

        # Grouped-query attention: expand k/v so every query head has a match.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        # key_padding_mask uses True = attend; masked positions receive -inf.
        if key_padding_mask is not None:
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~mask, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        # Trailing Nones keep the 3-tuple interface expected by TXBlock._sa_block.
        return output, None, None


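# Shape sketch for the grouped-query path (illustrative numbers): with
# d_model=512, n_heads=8, kv_n_heads=2 and input x of shape (B, S, 512),
#   q:    (B, 8, S, 64)
#   k, v: (B, 2, S, 64) -> repeat_interleave(n_rep=4, dim=1) -> (B, 8, S, 64)
# after which standard scaled dot-product attention applies unchanged.

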
class TXBlock(nn.Module):
    """Transformer encoder block"""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: float = 0.0,
        activation: str = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs,
    ):
        super().__init__()

        attn_config = attn_config or {}
        norm_config = norm_config or {}

        self.d_model = d_model
        self.n_heads = n_heads
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu

        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            dropout=attn_config.get("attn_pdrop", 0.0),
            device=device,
        )

        dim_feedforward = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)

        if use_glu:
            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)

        eps = norm_config.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)

        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)

        self.activation = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "silu": nn.SiLU(),
            "leaky_relu": nn.LeakyReLU(),
        }.get(activation, nn.GELU())

    def forward(
        self,
        x: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        if self.norm_scheme == "pre":
            # Pre-norm: normalize before each sublayer; residual adds raw x.
            x = x + self._sa_block(self.norm1(x), key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-norm: normalize after the residual connection.
            x = self.norm1(x + self._sa_block(x, key_padding_mask))
            x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x, _, _ = self.self_attn(x, key_padding_mask=key_padding_mask)
        return self.post_sa_dropout(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        if self.use_glu:
            # Gated FFN: down(act(gate(x)) * up(x)).
            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
        else:
            x = self.down_proj(self.activation(self.up_proj(x)))
        return self.post_ffn_dropout(x)


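# FFN variants implemented by TXBlock._ff_block, for reference:
#   standard: down_proj(activation(up_proj(x)))
#   GLU:      down_proj(activation(gate_proj(x)) * up_proj(x))
# With activation="silu" the GLU branch matches the SwiGLU pattern used in
# several recent transformer stacks.

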
class TXEncoder(nn.Module):
    """Stack of transformer encoder layers"""

    def __init__(
        self,
        encoder_layer: TXBlock,
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        **kwargs,
    ):
        super().__init__()

        norm_config = norm_config or {}

        # Independent deep copies of the template layer, mirroring
        # nn.TransformerEncoder. Rebuilding layers field-by-field would
        # silently drop settings such as kv_n_heads, attention dropout,
        # activation, and LayerNorm eps.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )

        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            self.norm = nn.LayerNorm(encoder_layer.d_model, eps=eps)

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        x = total_embs
        hidden_states = [] if output_hidden_states else None

        for layer in self.layers:
            x = layer(x, key_padding_mask=key_padding_mask)
            if output_hidden_states:
                hidden_states.append(x)

        if self.use_norm:
            x = self.norm(x)

        return x, hidden_states


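# Note: when output_hidden_states=True, hidden_states holds one entry per
# layer (post-layer activations only); unlike the usual HF convention the
# embedding output is not prepended, so len(hidden_states) == num_layers.

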
class GeneEncoder(nn.Module):
    """Gene embedding encoder"""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.use_norm = use_norm
        self.embedding = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.project = nn.Identity()

        if self.use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)
        x = self.project(x)
        if self.use_norm:
            x = self.enc_norm(x)
        return x


class ContinuousValueEncoder(nn.Module):
    """Encode continuous expression values"""

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)
        self.activation = {
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "leaky_relu": nn.LeakyReLU(),
        }.get(activation, nn.ReLU())
        self.linear2 = nn.Linear(d_model, d_model)
        self.use_norm = use_norm
        if use_norm:
            self.norm = nn.LayerNorm(d_model)
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        # (B, S) -> (B, S, 1); clip extreme values before projecting.
        x = x.unsqueeze(-1)
        x = torch.clamp(x, max=self.max_value)
        x = self.activation(self.linear1(x))
        x = self.linear2(x)
        if self.use_norm:
            x = self.norm(x)
        return self.dropout(x)


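# Shape sketch: values (B, S) -> (B, S, 1) -> two-layer MLP -> (B, S, d_model);
# each scalar expression value becomes a d_model-dimensional embedding that is
# summed with the gene token embedding in TXModel.forward.

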
class ExprDecoder(nn.Module):
    """Expression value decoder"""

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()
        self.activation = {
            "leaky_relu": nn.LeakyReLU(),
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
        }.get(activation, nn.LeakyReLU())
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(n_layers)]
        )
        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        for layer in self.linear_layers:
            x = self.activation(layer(x))
        pred_value = self.out_proj(x)
        if pred_value.shape[-1] == 1:
            pred_value = pred_value.squeeze(-1)
        return {"pred": pred_value}


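# When n_outputs == 1 the trailing dimension is squeezed, so "pred" has shape
# (B, S): one predicted expression value per input token.

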
class MVCDecoder(nn.Module):
    """Masked value prediction decoder"""

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ):
        super().__init__()
        self.scaled_dot_product = scaled_dot_product
        self.gene2query = nn.Linear(d_model, d_model)
        self.query_activation = {
            "sigmoid": nn.Sigmoid(),
            "relu": nn.ReLU(),
            "tanh": nn.Tanh(),
        }.get(query_activation, nn.Sigmoid())
        self.W = nn.Linear(d_model, d_model, bias=False)
        # Only the "inner product" architecture is implemented here; arch_style
        # is stored so the config round-trips.
        self.arch_style = arch_style

    def forward(self, cell_emb: Tensor, gene_embs: Tensor) -> Dict[str, Tensor]:
        query_vecs = self.query_activation(self.gene2query(gene_embs))
        cell_emb = cell_emb.unsqueeze(2)
        pred_value = torch.bmm(self.W(query_vecs), cell_emb).squeeze(2)

        if self.scaled_dot_product:
            pred_value = pred_value / torch.sqrt(
                torch.tensor(query_vecs.shape[-1], dtype=pred_value.dtype)
            )

        return {"pred": pred_value}


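# Inner-product readout shapes (illustrative):
#   gene_embs (B, S, d) -> gene2query + activation -> query_vecs (B, S, d)
#   cell_emb  (B, d)    -> unsqueeze               -> (B, d, 1)
#   bmm(W(query_vecs), cell_emb): (B, S, 1) -> squeeze -> (B, S)
# i.e. each gene's predicted value is the dot product between its query
# vector and the cell embedding.

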
class TXModel(nn.Module):
    """Transformer model for genomic data"""

    def __init__(self, config: TXConfig):
        super().__init__()

        self.config = config
        self.gene_encoder = GeneEncoder(
            config.vocab_size,
            config.d_model,
            padding_idx=config.pad_token_id,
            use_norm=config.gene_encoder_config.get("use_norm", False),
        )

        # Learned embedding marking whether a position's value is masked.
        self.flag_encoder = nn.Embedding(2, config.d_model)

        self.expression_encoder = ContinuousValueEncoder(
            d_model=config.d_model,
            dropout=config.expression_encoder_config.get("dropout", 0.1),
            max_value=config.expression_encoder_config.get("max_value", 512),
            activation=config.expression_encoder_config.get("activation", "relu"),
            use_norm=config.expression_encoder_config.get("use_norm", False),
        )

        encoder_layer = TXBlock(
            d_model=config.d_model,
            n_heads=config.n_heads,
            expansion_ratio=config.expansion_ratio,
            attn_config=config.attn_config,
            norm_config=config.norm_config,
            activation=config.transformer_activation,
            norm_scheme=config.norm_scheme,
            use_glu=config.use_glu,
        )

        self.transformer_encoder = TXEncoder(
            encoder_layer,
            config.n_layers,
            use_norm=config.norm_scheme == "pre",
            norm_config=config.norm_config,
        )

        self.expression_decoder = ExprDecoder(
            d_model=config.d_model,
            n_outputs=config.expression_decoder_config.get("n_outputs", 1),
            n_layers=config.expression_decoder_config.get("n_layers", 2),
            activation=config.expression_decoder_config.get("activation", "leaky_relu"),
        )

        if config.mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=config.d_model,
                arch_style=config.mvc_config.get("arch_style", "inner product"),
                query_activation=config.mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=config.mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        token_embs = self.gene_encoder(genes)
        token_values = self.expression_encoder(values)
        # Zero out value embeddings at masked positions...
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)

        # ...and add the mask flag embedding there instead.
        flag = self.flag_encoder(torch.tensor(1, device=token_embs.device)).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag

        total_embs = token_embs + token_values + flag_embs

        # Cached for the MVC decoder below.
        self.cur_gene_token_embs = token_embs

        transformer_output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )

        # "cls"-style cell embedding: representation of the first token.
        cell_emb = transformer_output[:, 0, :]

        output = {
            "transformer_output": transformer_output,
            "cell_emb": cell_emb,
        }

        if output_hidden_states:
            output["hidden_states"] = hidden_states

        if skip_decoders:
            return output

        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]

        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(cell_emb, self.cur_gene_token_embs)
            output["mvc_output"] = mvc_output["pred"]

        return output


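# Input contract for TXModel.forward (shapes as used above):
#   genes:            (B, S) long  - gene token ids; the cell embedding is
#                                    read from position 0
#   values:           (B, S) float - expression values
#   gen_masks:        (B, S) bool  - True where the value is masked out for
#                                    prediction
#   key_padding_mask: (B, S) bool  - True = position may be attended to

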
class TXPreTrainedModel(PreTrainedModel):
    """Base class for TXModel"""

    config_class = TXConfig
    base_model_prefix = "tx_model"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class TXModelForHF(TXPreTrainedModel):
    """
    HuggingFace-compatible TXModel

    Requires ONLY: transformers, torch, safetensors
    """

    def __init__(self, config: TXConfig):
        super().__init__(config)
        self.tx_model = TXModel(config)
        self.post_init()

    def forward(
        self,
        genes: torch.Tensor,
        values: torch.Tensor,
        gen_masks: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutput]:
        # Default mask: attend to everything that is not a pad token.
        if key_padding_mask is None:
            key_padding_mask = ~genes.eq(self.config.pad_token_id)

        outputs = self.tx_model(
            genes=genes,
            values=values,
            gen_masks=gen_masks,
            key_padding_mask=key_padding_mask,
            skip_decoders=skip_decoders,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            return tuple(v for v in outputs.values())

        hidden_states = outputs.get("hidden_states") if output_hidden_states else None
        # Note: last_hidden_state carries the pooled cell embedding (B, d_model);
        # the full per-token output stays available via return_dict=False.
        return BaseModelOutput(
            last_hidden_state=outputs.get("cell_emb"),
            hidden_states=tuple(hidden_states) if hidden_states is not None else None,
        )

    def get_input_embeddings(self):
        return self.tx_model.gene_encoder.embedding

    def set_input_embeddings(self, value):
        self.tx_model.gene_encoder.embedding = value


# Convenience aliases so the model resolves under other names (e.g. via
# auto_map entries in config.json when this file is loaded as remote code).
TXForCausalLM = TXModelForHF
AutoModelForCausalLM = TXModelForHF


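if __name__ == "__main__":
    # Minimal smoke test with hypothetical sizes (no pretrained weights);
    # illustrative only, not tied to any released checkpoint.
    config = TXConfig(vocab_size=100, d_model=32, n_layers=2, n_heads=4)
    model = TXModelForHF(config)
    model.eval()

    genes = torch.randint(1, 100, (2, 16))            # (B, S) token ids, no padding
    values = torch.rand(2, 16)                        # (B, S) expression values
    gen_masks = torch.zeros(2, 16, dtype=torch.bool)  # nothing masked

    with torch.no_grad():
        out = model(genes=genes, values=values, gen_masks=gen_masks)

    print(out.last_hidden_state.shape)  # torch.Size([2, 32]) - pooled cell embedding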