"""
Standalone implementation of TXModel without external dependencies.

Only requires: torch, transformers, safetensors
"""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from blocks_standalone import (
    ChemEncoder,
    ContinuousValueEncoder,
    ExprDecoder,
    GeneEncoder,
    MVCDecoder,
    TXBlock,
    TXEncoder,
)
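
# NOTE: `blocks_standalone` is assumed to sit alongside this file; it provides
# the encoder/decoder building blocks imported above.
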

class TXModel(nn.Module):
    """Standalone Transformer model for genomic data."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        expansion_ratio: int,
        pad_token_id: int,
        pad_value: float,
        num_bins: int,
        norm_scheme: str = "pre",
        transformer_activation: str = "gelu",
        cell_emb_style: str = "cls",
        use_chem_token: bool = False,
        attn_config: Optional[dict] = None,
        norm_config: Optional[dict] = None,
        gene_encoder_config: Optional[dict] = None,
        expression_encoder_config: Optional[dict] = None,
        expression_decoder_config: Optional[dict] = None,
        mvc_config: Optional[dict] = None,
        chemical_encoder_config: Optional[dict] = None,
        use_glu: bool = False,
        return_gene_embeddings: bool = False,
        keep_first_n_tokens: int = 1,
        device: Optional[str] = None,
    ):
        super().__init__()

        self.model_type = "Transformer"
        self.device = device
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_model = d_model
        self.expansion_ratio = expansion_ratio
        self.norm_scheme = norm_scheme
        self.transformer_activation = transformer_activation
        self.use_chem_token = use_chem_token
        self.cell_emb_style = cell_emb_style
        self.pad_token_id = pad_token_id
        self.pad_value = pad_value
        self.n_input_bins = num_bins
        self.keep_first_n_tokens = keep_first_n_tokens
        self.return_gene_embeddings = return_gene_embeddings

        # Fall back to empty / minimal sub-module configs when none are given.
        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}
        if gene_encoder_config is None:
            gene_encoder_config = {"use_norm": False}
        if expression_encoder_config is None:
            expression_encoder_config = {}
        if expression_decoder_config is None:
            expression_decoder_config = {}

        # Token (gene ID) embeddings.
        self.gene_encoder = GeneEncoder(
            self.vocab_size,
            self.d_model,
            padding_idx=self.pad_token_id,
            use_norm=gene_encoder_config.get("use_norm", False),
            gene_encoder_cfg=gene_encoder_config,
        )

        # Binary flag embedding marking positions whose expression value is
        # masked out for generation (see `transformer_generate`).
        self.flag_encoder = nn.Embedding(2, self.d_model)

        # Continuous expression-value embeddings.
        self.expression_encoder = ContinuousValueEncoder(
            d_model=self.d_model,
            dropout=expression_encoder_config.get("dropout", 0.1),
            max_value=expression_encoder_config.get("max_value", 512),
            activation=expression_encoder_config.get("activation", "relu"),
            use_norm=expression_encoder_config.get("use_norm", False),
        )

        # Optional chemical (drug) encoder; its output replaces the embedding
        # at position 1 when `use_chem_token` is enabled.
        if self.use_chem_token:
            if chemical_encoder_config is None:
                chemical_encoder_config = {}
            self.chem_encoder = ChemEncoder(
                d_out=self.d_model,
                padding_idx=chemical_encoder_config.get("padding_idx", 0),
                activation=chemical_encoder_config.get("activation", "leaky_relu"),
                freeze=chemical_encoder_config.get("freeze", False),
                num_drugs=chemical_encoder_config.get("num_drugs", 1000),
                fp_dim=chemical_encoder_config.get("fp_dim", 2048),
            )
        # Stack of identical transformer blocks.
        encoder_layer = TXBlock(
            d_model=self.d_model,
            n_heads=self.n_heads,
            expansion_ratio=self.expansion_ratio,
            attn_config=attn_config,
            norm_config=norm_config,
            activation=self.transformer_activation,
            device=self.device,
            norm_scheme=self.norm_scheme,
            use_glu=use_glu,
        )

        self.transformer_encoder = TXEncoder(
            encoder_layer,
            self.n_layers,
            # Pre-norm transformers apply a final normalization layer.
            use_norm=(self.norm_scheme == "pre"),
            norm_config=norm_config,
            attn_config=attn_config,
        )

        # Per-token expression prediction head.
        self.expression_decoder = ExprDecoder(
            d_model=self.d_model,
            n_outputs=expression_decoder_config.get("n_outputs", 1),
            n_layers=expression_decoder_config.get("n_layers", 2),
            activation=expression_decoder_config.get("activation", "leaky_relu"),
        )

        # Optional masked-value-prediction (MVC) head driven by the cell
        # embedding and the gene token embeddings.
        if mvc_config is not None:
            self.mvc_decoder = MVCDecoder(
                d_model=self.d_model,
                arch_style=mvc_config.get("arch_style", "inner product"),
                query_activation=mvc_config.get("query_activation", "sigmoid"),
                scaled_dot_product=mvc_config.get("scaled_dot_product", False),
            )
        else:
            self.mvc_decoder = None
    def transformer_generate(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        """Embed the inputs and run them through the transformer encoder.

        Returns the final hidden states and, if requested, the list of
        intermediate hidden states.
        """
        token_embs = self.gene_encoder(genes)

        # Embed the expression values and zero them at positions that are
        # masked out for generation.
        token_values = self.expression_encoder(values)
        token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)

        # Add a learned flag embedding at the positions to be generated so
        # the model can tell them apart from observed values.
        flag = self.flag_encoder(
            torch.tensor(1, device=token_embs.device)
        ).reshape(1, 1, -1)
        flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag

        total_embs = token_embs + token_values + flag_embs

        # The chem token sits at position 1, right after the CLS token.
        if self.use_chem_token and drug_ids is not None:
            drug_embs = self.chem_encoder(drug_ids)
            total_embs[:, 1, :] = drug_embs

        # Cache the gene token embeddings for the MVC decoder (see `forward`).
        self.cur_gene_token_embs = token_embs

        output, hidden_states = self.transformer_encoder(
            total_embs=total_embs,
            key_padding_mask=key_padding_mask,
            output_hidden_states=output_hidden_states,
        )

        return output, hidden_states

    def forward(
        self,
        genes: Tensor,
        values: Tensor,
        gen_masks: Tensor,
        key_padding_mask: Tensor,
        drug_ids: Optional[Tensor] = None,
        skip_decoders: bool = False,
        output_hidden_states: bool = False,
    ) -> dict:
        """Full forward pass.

        Returns a dict with "transformer_output" and "cell_emb", plus
        "hidden_states", "gene_embeddings", "expr_preds" and "mvc_output"
        depending on the configuration and arguments.
        """
        transformer_output, hidden_states = self.transformer_generate(
            genes, values, gen_masks, key_padding_mask,
            drug_ids, output_hidden_states,
        )

        output = {
            "transformer_output": transformer_output,
        }

        if output_hidden_states:
            output["hidden_states"] = hidden_states

        # Pool a per-cell embedding from the token-level outputs.
        if self.cell_emb_style == "cls":
            cell_emb = transformer_output[:, 0, :]
        elif self.cell_emb_style in ("avg-pool", "w-pool"):
            # Masked mean pool; "w-pool" currently shares this implementation.
            # Assumes `key_padding_mask` is 1 at positions to include.
            mask = key_padding_mask.unsqueeze(-1).float()
            cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            # Fall back to the CLS token.
            cell_emb = transformer_output[:, 0, :]

        output["cell_emb"] = cell_emb

        if self.return_gene_embeddings:
            output["gene_embeddings"] = transformer_output

        if skip_decoders:
            return output

        expr_output = self.expression_decoder(transformer_output)
        output["expr_preds"] = expr_output["pred"]

        if self.mvc_decoder is not None:
            mvc_output = self.mvc_decoder(
                cell_emb,
                self.cur_gene_token_embs,
            )
            output["mvc_output"] = mvc_output["pred"]

        return output

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from pretrained weights."""
        import json
        from pathlib import Path

        from safetensors.torch import load_file
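
        # Expected checkpoint layout (inferred from the reads below):
        #   model_path/
        #       config.json         constructor arguments
        #       model.safetensors   weights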

        model_path = Path(model_path)

        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        model = cls(
            vocab_size=config["vocab_size"],
            d_model=config["d_model"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            expansion_ratio=config["expansion_ratio"],
            pad_token_id=config["pad_token_id"],
            pad_value=config["pad_value"],
            num_bins=config["num_bins"],
            norm_scheme=config.get("norm_scheme", "pre"),
            transformer_activation=config.get("transformer_activation", "gelu"),
            cell_emb_style=config.get("cell_emb_style", "cls"),
            use_chem_token=config.get("use_chem_token", False),
            attn_config=config.get("attn_config"),
            norm_config=config.get("norm_config"),
            gene_encoder_config=config.get("gene_encoder_config"),
            expression_encoder_config=config.get("expression_encoder_config"),
            expression_decoder_config=config.get("expression_decoder_config"),
            mvc_config=config.get("mvc_config"),
            chemical_encoder_config=config.get("chemical_encoder_config"),
            use_glu=config.get("use_glu", False),
            return_gene_embeddings=config.get("return_gene_embeddings", False),
            keep_first_n_tokens=config.get("keep_first_n_tokens", 1),
        )

        state_dict = load_file(model_path / "model.safetensors")

        # Strip wrapper prefixes left over from training harnesses
        # (e.g. Lightning modules). Longer prefixes are checked first.
        prefixes = ("model.tx_model.", "tx_model.", "model.")
        new_state_dict = {}
        for k, v in state_dict.items():
            for prefix in prefixes:
                if k.startswith(prefix):
                    k = k[len(prefix):]
                    break
            new_state_dict[k] = v

        # strict=False tolerates missing/extra keys such as the optional
        # MVC or chem sub-modules.
        model.load_state_dict(new_state_dict, strict=False)

        return model
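

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model API). The
# hyperparameter values, tensor shapes, and mask semantics below are
# assumptions inferred from how `forward` consumes its inputs: `genes` and
# `values` are (batch, seq_len), `gen_masks` is True at positions whose value
# should be generated, and `key_padding_mask` is assumed to be 1/True at
# positions to attend over.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = TXModel(
        vocab_size=1000,
        d_model=64,
        n_layers=2,
        n_heads=4,
        expansion_ratio=4,
        pad_token_id=0,
        pad_value=-2.0,
        num_bins=51,
    )

    batch, seq_len = 2, 16
    genes = torch.randint(1, 1000, (batch, seq_len))
    values = torch.rand(batch, seq_len)
    gen_masks = torch.zeros(batch, seq_len, dtype=torch.bool)
    key_padding_mask = torch.ones(batch, seq_len, dtype=torch.bool)

    out = model(genes, values, gen_masks, key_padding_mask)
    print(out["transformer_output"].shape)  # expected: (batch, seq_len, d_model)
    print(out["cell_emb"].shape)            # expected: (batch, d_model)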