# Copyright (C) Tahoe Therapeutics 2025. All rights reserved.
"""
Standalone implementation of TXModel without external dependencies.
Only requires: torch, transformers, safetensors
"""
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from blocks_standalone import (
ChemEncoder,
ContinuousValueEncoder,
ExprDecoder,
GeneEncoder,
MVCDecoder,
TXBlock,
TXEncoder,
)
class TXModel(nn.Module):
"""Standalone Transformer model for genomic data"""
def __init__(
self,
vocab_size: int,
d_model: int,
n_layers: int,
n_heads: int,
expansion_ratio: int,
pad_token_id: int,
pad_value: float,
num_bins: int,
norm_scheme: str = "pre",
transformer_activation: str = "gelu",
cell_emb_style: str = "cls",
use_chem_token: bool = False,
attn_config: Optional[dict] = None,
norm_config: Optional[dict] = None,
gene_encoder_config: Optional[dict] = None,
expression_encoder_config: Optional[dict] = None,
expression_decoder_config: Optional[dict] = None,
mvc_config: Optional[dict] = None,
chemical_encoder_config: Optional[dict] = None,
use_glu: bool = False,
return_gene_embeddings: bool = False,
keep_first_n_tokens: int = 1,
device: Optional[str] = None,
):
super().__init__()
self.model_type = "Transformer"
self.device = device
self.vocab_size = vocab_size
self.n_layers = n_layers
self.n_heads = n_heads
self.d_model = d_model
self.expansion_ratio = expansion_ratio
self.norm_scheme = norm_scheme
self.transformer_activation = transformer_activation
self.use_chem_token = use_chem_token
self.cell_emb_style = cell_emb_style
self.pad_token_id = pad_token_id
self.pad_value = pad_value
self.n_input_bins = num_bins
self.keep_first_n_tokens = keep_first_n_tokens
self.return_gene_embeddings = return_gene_embeddings
if attn_config is None:
attn_config = {}
if norm_config is None:
norm_config = {}
if gene_encoder_config is None:
gene_encoder_config = {"use_norm": False}
if expression_encoder_config is None:
expression_encoder_config = {}
if expression_decoder_config is None:
expression_decoder_config = {}
# Gene encoder
self.gene_encoder = GeneEncoder(
self.vocab_size,
self.d_model,
padding_idx=self.pad_token_id,
use_norm=gene_encoder_config.get("use_norm", False),
gene_encoder_cfg=gene_encoder_config,
)
# Flag encoder
self.flag_encoder = nn.Embedding(2, self.d_model)
# Expression encoder
self.expression_encoder = ContinuousValueEncoder(
d_model=self.d_model,
dropout=expression_encoder_config.get("dropout", 0.1),
max_value=expression_encoder_config.get("max_value", 512),
activation=expression_encoder_config.get("activation", "relu"),
use_norm=expression_encoder_config.get("use_norm", False),
)
# Chemical encoder (if needed)
if self.use_chem_token:
if chemical_encoder_config is None:
chemical_encoder_config = {}
self.chem_encoder = ChemEncoder(
d_out=self.d_model,
padding_idx=chemical_encoder_config.get("padding_idx", 0),
activation=chemical_encoder_config.get("activation", "leaky_relu"),
freeze=chemical_encoder_config.get("freeze", False),
num_drugs=chemical_encoder_config.get("num_drugs", 1000),
fp_dim=chemical_encoder_config.get("fp_dim", 2048),
)
# Transformer encoder
encoder_layer = TXBlock(
d_model=self.d_model,
n_heads=self.n_heads,
expansion_ratio=self.expansion_ratio,
attn_config=attn_config,
norm_config=norm_config,
activation=self.transformer_activation,
device=self.device,
norm_scheme=self.norm_scheme,
use_glu=use_glu,
)
self.transformer_encoder = TXEncoder(
encoder_layer,
self.n_layers,
use_norm=self.norm_scheme == "pre",
norm_config=norm_config,
attn_config=attn_config,
)
# Expression decoder
self.expression_decoder = ExprDecoder(
d_model=self.d_model,
n_outputs=expression_decoder_config.get("n_outputs", 1),
n_layers=expression_decoder_config.get("n_layers", 2),
activation=expression_decoder_config.get("activation", "leaky_relu"),
)
# MVC decoder (if configured)
if mvc_config is not None:
self.mvc_decoder = MVCDecoder(
d_model=self.d_model,
arch_style=mvc_config.get("arch_style", "inner product"),
query_activation=mvc_config.get("query_activation", "sigmoid"),
scaled_dot_product=mvc_config.get("scaled_dot_product", False),
)
else:
self.mvc_decoder = None
def transformer_generate(
self,
genes: Tensor,
values: Tensor,
gen_masks: Tensor,
key_padding_mask: Tensor,
drug_ids: Optional[Tensor] = None,
output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
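        """Embed genes and expression values, then run the transformer encoder.

        Args (shapes inferred from usage in this file):
            genes: (batch, seq_len) long tensor of gene token ids.
            values: (batch, seq_len) float tensor of expression values.
            gen_masks: (batch, seq_len) bool tensor; True marks positions whose
                expression value is hidden (zeroed and flagged for generation).
            key_padding_mask: (batch, seq_len) bool mask forwarded to the
                transformer encoder.
            drug_ids: optional drug ids for the chemical token at position 1.
            output_hidden_states: whether the encoder should also collect
                per-layer hidden states.

        Returns:
            A ``(transformer_output, hidden_states)`` tuple.
        """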
# Encode genes
token_embs = self.gene_encoder(genes)
# Encode expression values
token_values = self.expression_encoder(values)
token_values = token_values.masked_fill(gen_masks.unsqueeze(-1), 0.0)
# Flag embeddings
flag = self.flag_encoder(
torch.tensor(1, device=token_embs.device)
).reshape(1, 1, -1)
flag_embs = gen_masks.unsqueeze(-1).to(token_embs.dtype) * flag
# Combine embeddings
total_embs = token_embs + token_values + flag_embs
        # Overwrite position 1 (the slot reserved for the chem token, directly
        # after the CLS token) with the chemical embedding
        if self.use_chem_token and drug_ids is not None:
            drug_embs = self.chem_encoder(drug_ids)
            total_embs[:, 1, :] = drug_embs
# Store gene embeddings for MVC
self.cur_gene_token_embs = token_embs
# Pass through transformer
output, hidden_states = self.transformer_encoder(
total_embs=total_embs,
key_padding_mask=key_padding_mask,
output_hidden_states=output_hidden_states,
)
return output, hidden_states
def forward(
self,
genes: Tensor,
values: Tensor,
gen_masks: Tensor,
key_padding_mask: Tensor,
drug_ids: Optional[Tensor] = None,
skip_decoders: bool = False,
output_hidden_states: bool = False,
) -> dict:
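        """Full forward pass: transformer encoding plus decoder heads.

        Returns a dict with, depending on the flags:
            "transformer_output": final hidden states, (batch, seq_len, d_model).
            "hidden_states": per-layer states, if ``output_hidden_states``.
            "cell_emb": pooled cell embedding, (batch, d_model).
            "gene_embeddings": the transformer output, if
                ``return_gene_embeddings`` was set.
            "expr_preds": expression decoder predictions.
            "mvc_output": MVC decoder predictions, if an MVC decoder is
                configured.
        """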
# Generate transformer output
transformer_output, hidden_states = self.transformer_generate(
genes, values, gen_masks, key_padding_mask,
drug_ids, output_hidden_states
)
# Prepare output dict
output = {
"transformer_output": transformer_output,
}
if output_hidden_states:
output["hidden_states"] = hidden_states
        # Cell embedding: CLS token or pooling over the sequence
        if self.cell_emb_style == "cls":
            cell_emb = transformer_output[:, 0, :]
        elif self.cell_emb_style in ("avg-pool", "w-pool"):
            # Weighted pooling ("w-pool") is not implemented in this standalone
            # version and falls back to average pooling. The average is taken
            # over positions selected by key_padding_mask, i.e. this assumes
            # True marks tokens to include.
            mask = key_padding_mask.unsqueeze(-1).float()
            cell_emb = (transformer_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            # Unrecognized style: default to the CLS token
            cell_emb = transformer_output[:, 0, :]
output["cell_emb"] = cell_emb
# Return gene embeddings if requested
if self.return_gene_embeddings:
output["gene_embeddings"] = transformer_output
# Skip decoders if requested
if skip_decoders:
return output
# Expression decoder
expr_output = self.expression_decoder(transformer_output)
output["expr_preds"] = expr_output["pred"]
# MVC decoder (if available)
if self.mvc_decoder is not None:
mvc_output = self.mvc_decoder(
cell_emb,
self.cur_gene_token_embs,
)
output["mvc_output"] = mvc_output["pred"]
return output
@classmethod
def from_pretrained(cls, model_path: str, **kwargs):
"""Load model from pretrained weights"""
from safetensors.torch import load_file
import json
from pathlib import Path
model_path = Path(model_path)
# Load config
with open(model_path / "config.json", "r") as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
d_model=config["d_model"],
n_layers=config["n_layers"],
n_heads=config["n_heads"],
expansion_ratio=config["expansion_ratio"],
pad_token_id=config["pad_token_id"],
pad_value=config["pad_value"],
num_bins=config["num_bins"],
norm_scheme=config.get("norm_scheme", "pre"),
transformer_activation=config.get("transformer_activation", "gelu"),
cell_emb_style=config.get("cell_emb_style", "cls"),
use_chem_token=config.get("use_chem_token", False),
attn_config=config.get("attn_config"),
norm_config=config.get("norm_config"),
gene_encoder_config=config.get("gene_encoder_config"),
expression_encoder_config=config.get("expression_encoder_config"),
expression_decoder_config=config.get("expression_decoder_config"),
mvc_config=config.get("mvc_config"),
chemical_encoder_config=config.get("chemical_encoder_config"),
use_glu=config.get("use_glu", False),
return_gene_embeddings=config.get("return_gene_embeddings", False),
keep_first_n_tokens=config.get("keep_first_n_tokens", 1),
)
# Load weights
state_dict = load_file(model_path / "model.safetensors")
        # Strip a leading 'model.tx_model.', 'tx_model.', or 'model.' prefix
        # (if present) so the keys match this standalone module's state dict
        new_state_dict = {}
        for k, v in state_dict.items():
            new_key = k
            for prefix in ("model.tx_model.", "tx_model.", "model."):
                if k.startswith(prefix):
                    new_key = k[len(prefix):]
                    break
            new_state_dict[new_key] = v
        # strict=False tolerates missing or unexpected keys, e.g. optional
        # decoder heads absent from the checkpoint
        model.load_state_dict(new_state_dict, strict=False)
return model
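

# Minimal smoke-test sketch (illustrative, not part of the original module):
# builds a tiny TXModel with assumed hyperparameters and runs a forward pass on
# random data to show the expected input shapes. It assumes blocks_standalone
# is importable and that key_padding_mask follows the standard PyTorch
# convention (True = padding); the demo passes an all-False mask (no padding).
if __name__ == "__main__":
    batch, seq_len, vocab = 2, 16, 100
    model = TXModel(
        vocab_size=vocab,
        d_model=32,
        n_layers=2,
        n_heads=4,
        expansion_ratio=4,
        pad_token_id=0,
        pad_value=-1.0,
        num_bins=51,
    )
    genes = torch.randint(1, vocab, (batch, seq_len))          # gene token ids
    values = torch.rand(batch, seq_len)                        # expression values
    gen_masks = torch.zeros(batch, seq_len, dtype=torch.bool)  # nothing masked
    key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)  # no padding
    out = model(genes, values, gen_masks, key_padding_mask)
    print({k: tuple(v.shape) for k, v in out.items() if isinstance(v, Tensor)})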