|
|
|
|
|
""" |
|
|
Standalone implementation of TXModel blocks without external dependencies. |
|
|
Only requires: torch, transformers |
|
|
""" |
|
|
|
|
|
import math |
|
|
from typing import Optional, Dict, Any, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor, nn |
|
|
|
|
|
|
|
|
class MultiheadAttention(nn.Module):
    """Standard multi-head attention with optional grouped-query attention.

    When ``kv_n_heads < n_heads``, key/value heads are shared across query
    heads (grouped-query attention) and repeated to match ``n_heads``.

    Args:
        d_model: Total embedding dimension; must be divisible by ``n_heads``.
        n_heads: Number of query heads.
        kv_n_heads: Number of key/value heads; must divide ``n_heads``.
            Defaults to ``n_heads`` (standard multi-head attention).
        dropout: Dropout probability applied to the attention weights.
        bias: Whether the projection layers use bias terms.
        device: Optional device for the projection layers.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``, or
            ``n_heads`` is not divisible by ``kv_n_heads``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        kv_n_heads: Optional[int] = None,
        dropout: float = 0.0,
        bias: bool = True,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.kv_n_heads = kv_n_heads if kv_n_heads is not None else n_heads

        # Fail fast on misconfiguration: the integer divisions below would
        # otherwise silently produce a wrong head_dim / n_rep.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if n_heads % self.kv_n_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by kv_n_heads ({self.kv_n_heads})"
            )

        self.head_dim = d_model // n_heads
        self.dropout = dropout

        # Number of times each kv head is repeated to serve the query heads.
        self.n_rep = n_heads // self.kv_n_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=bias, device=device)
        self.k_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
        self.v_proj = nn.Linear(d_model, self.kv_n_heads * self.head_dim, bias=bias, device=device)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias, device=device)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        is_causal: bool = False,
        **kwargs
    ) -> Tuple[Tensor, None, None]:
        """Self-attention over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            attn_bias: Optional additive bias broadcastable to the score
                tensor ``(batch, n_heads, seq_len, seq_len)``.
            key_padding_mask: Optional boolean mask of shape
                ``(batch, seq_len)`` where True marks VALID positions
                (False positions are masked out).
            is_causal: If True, apply a causal (upper-triangular) mask.

        Returns:
            Tuple of ``(output, None, None)``; the Nones keep the return
            signature compatible with attention APIs that also return
            weights and cached key/values.
        """
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: (batch, heads, seq, head_dim).
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.kv_n_heads, self.head_dim)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Grouped-query attention: repeat kv heads to match the query heads.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attn_bias is not None:
            attn_scores = attn_scores + attn_bias

        if key_padding_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads/queries.
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

        if is_causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
                diagonal=1
            )
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        # NOTE(review): a fully masked row (all -inf) softmaxes to NaN;
        # callers are expected to keep at least one valid key per query.
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        output = torch.matmul(attn_weights, v)

        # Merge heads back: (batch, seq, d_model).
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        return output, None, None
|
|
|
|
|
|
|
|
class TXBlock(nn.Module):
    """Transformer encoder block with pre/post normalization support.

    The feed-forward sub-block is either a standard MLP
    (``down(act(up(x)))``) or, when ``use_glu`` is True, a gated variant
    (``down(act(gate(x)) * up(x))``).

    Args:
        d_model: Embedding dimension.
        n_heads: Number of attention heads.
        expansion_ratio: FFN hidden size as a multiple of ``d_model``.
        attn_config: Optional dict; recognized keys are ``kv_n_heads``
            and ``attn_pdrop``.
        norm_config: Optional dict; recognized key is ``eps``.
        dropout: Dropout applied after the attention and FFN sub-blocks.
            ``None`` is treated as 0.0.
        activation: One of "gelu", "relu", "silu"/"swish", "leaky_relu".
        device: Optional device for submodules.
        norm_scheme: "pre" (pre-LN) or "post" (post-LN).
        use_glu: Use the gated (GLU-style) feed-forward variant.

    Raises:
        ValueError: If ``norm_scheme`` or ``activation`` is unknown.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        norm_config: Optional[Dict] = None,
        dropout: Optional[float] = 0.0,
        activation: Optional[str] = "gelu",
        device: Optional[str] = None,
        norm_scheme: str = "pre",
        use_glu: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # Validate configuration BEFORE building submodules so a bad config
        # fails fast instead of after parameter allocation (previously this
        # check ran at the very end of __init__).
        if norm_scheme not in ["pre", "post"]:
            raise ValueError("norm_scheme must be either pre or post")

        if attn_config is None:
            attn_config = {}
        if norm_config is None:
            norm_config = {}
        if dropout is None:
            # nn.Dropout(None) would raise; treat None as "no dropout",
            # consistent with the Optional[float] annotation.
            dropout = 0.0

        self.d_model = d_model
        self.n_heads = n_heads
        self.device = device
        self.norm_scheme = norm_scheme
        self.use_glu = use_glu

        kv_n_heads = attn_config.get("kv_n_heads", n_heads)
        self.self_attn = MultiheadAttention(
            d_model=d_model,
            n_heads=n_heads,
            kv_n_heads=kv_n_heads,
            dropout=attn_config.get("attn_pdrop", 0.0),
            device=device,
        )

        dim_feedforward = d_model * expansion_ratio
        self.up_proj = nn.Linear(d_model, dim_feedforward, device=device)
        self.down_proj = nn.Linear(dim_feedforward, d_model, device=device)

        if use_glu:
            self.gate_proj = nn.Linear(d_model, dim_feedforward, device=device)

        eps = norm_config.get("eps", 1e-5)
        self.norm1 = nn.LayerNorm(d_model, eps=eps, device=device)
        self.norm2 = nn.LayerNorm(d_model, eps=eps, device=device)

        self.post_sa_dropout = nn.Dropout(dropout)
        self.post_ffn_dropout = nn.Dropout(dropout)

        self.activation = self._get_activation_fn(activation)

    @staticmethod
    def _get_activation_fn(activation: str) -> nn.Module:
        """Map an activation name to its module; raise on unknown names."""
        if activation == "gelu":
            return nn.GELU()
        elif activation == "relu":
            return nn.ReLU()
        elif activation == "silu" or activation == "swish":
            return nn.SiLU()
        elif activation == "leaky_relu":
            return nn.LeakyReLU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        **kwargs
    ) -> Tensor:
        """Apply self-attention and feed-forward with residual connections.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            attn_bias: Optional additive attention bias.
            key_padding_mask: Optional boolean mask (True = valid position).

        Returns:
            Tensor of the same shape as ``x``.
        """
        if self.norm_scheme == "pre":
            # Pre-LN: normalize the input of each sub-block.
            x = x + self._sa_block(
                self.norm1(x),
                attn_bias=attn_bias,
                key_padding_mask=key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x))
        else:
            # Post-LN: normalize after each residual addition.
            x = self.norm1(
                x + self._sa_block(
                    x,
                    attn_bias=attn_bias,
                    key_padding_mask=key_padding_mask,
                )
            )
            x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_bias: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Self-attention sub-block (non-causal) followed by dropout."""
        x, _, _ = self.self_attn(
            x,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=False,
        )
        return self.post_sa_dropout(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sub-block (GLU-gated or plain MLP) with dropout."""
        if self.use_glu:
            x = self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
        else:
            x = self.down_proj(self.activation(self.up_proj(x)))
        return self.post_ffn_dropout(x)
|
|
|
|
|
|
|
|
class TXEncoder(nn.Module):
    """Stack of transformer encoder layers built from a prototype layer.

    ``encoder_layer`` serves as a configuration prototype: each of the
    ``num_layers`` layers is a freshly initialized ``TXBlock`` with the
    same hyperparameters (the prototype's weights are NOT shared or
    copied).

    Args:
        encoder_layer: Prototype block whose configuration is replicated.
        num_layers: Number of layers in the stack.
        use_norm: Apply a final LayerNorm after the last layer.
        norm_config: Optional dict; recognized key is ``eps``.
        attn_config: Optional attention config forwarded to each layer.
    """

    # Maps a prototype's activation module back to its constructor name so
    # the stack preserves the prototype's activation (previously it was
    # silently reset to "gelu").
    _ACTIVATION_NAMES = {
        nn.GELU: "gelu",
        nn.ReLU: "relu",
        nn.SiLU: "silu",
        nn.LeakyReLU: "leaky_relu",
    }

    def __init__(
        self,
        encoder_layer: TXBlock,
        num_layers: int,
        use_norm: bool = False,
        norm_config: Optional[Dict] = None,
        attn_config: Optional[Dict] = None,
    ):
        super().__init__()

        if norm_config is None:
            norm_config = {}

        # Recover per-layer settings from the prototype instead of
        # hardcoding them, so the stack faithfully mirrors its config.
        activation = self._ACTIVATION_NAMES.get(
            type(encoder_layer.activation), "gelu"
        )
        dropout = encoder_layer.post_sa_dropout.p

        self.layers = nn.ModuleList([
            TXBlock(
                d_model=encoder_layer.d_model,
                n_heads=encoder_layer.n_heads,
                expansion_ratio=encoder_layer.up_proj.out_features // encoder_layer.d_model,
                attn_config=attn_config,
                norm_config=norm_config,
                dropout=dropout,
                activation=activation,
                device=encoder_layer.device,
                norm_scheme=encoder_layer.norm_scheme,
                use_glu=encoder_layer.use_glu,
            )
            for _ in range(num_layers)
        ])

        self.use_norm = use_norm
        if use_norm:
            eps = norm_config.get("eps", 1e-5)
            # Keep the final norm on the same device as the layers
            # (previously it was always created on the default device).
            self.norm = nn.LayerNorm(
                encoder_layer.d_model, eps=eps, device=encoder_layer.device
            )

    def forward(
        self,
        total_embs: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[Tensor, Optional[list]]:
        """Run the input through all layers.

        Args:
            total_embs: Input of shape ``(batch, seq_len, d_model)``.
            key_padding_mask: Optional boolean mask (True = valid position).
            output_hidden_states: Also collect each layer's output.

        Returns:
            Tuple of (final hidden state, list of per-layer hidden states,
            or None when ``output_hidden_states`` is False).
        """
        x = total_embs
        hidden_states = [] if output_hidden_states else None

        for layer in self.layers:
            x = layer(
                x,
                attn_bias=None,
                key_padding_mask=key_padding_mask,
            )

            if output_hidden_states:
                hidden_states.append(x)

        if self.use_norm:
            x = self.norm(x)

        return x, hidden_states
|
|
|
|
|
|
|
|
class GeneEncoder(nn.Module):
    """Embed gene token ids, with an optional projection and LayerNorm."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int = 0,
        use_norm: bool = False,
        gene_encoder_cfg: Optional[Dict] = None,
    ):
        super().__init__()

        gene_encoder_cfg = {} if gene_encoder_cfg is None else gene_encoder_cfg

        self.use_norm = use_norm
        self.embedding = nn.Embedding(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
        )

        # Placeholder projection; kept as Identity in this standalone version.
        self.project = nn.Identity()

        if self.use_norm:
            self.enc_norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Map integer ids of shape (...,) to embeddings (..., embedding_dim)."""
        embedded = self.project(self.embedding(x))
        return self.enc_norm(embedded) if self.use_norm else embedded
|
|
|
|
|
|
|
|
class ChemEncoder(nn.Module):
    """Chemical compound encoder.

    Looks up a (placeholder, all-zero) fingerprint per drug id and maps it
    through a small MLP to a ``d_out``-dimensional embedding.
    """

    def __init__(
        self,
        d_out: int,
        padding_idx: int = 0,
        activation: str = "leaky_relu",
        use_norm: bool = True,
        freeze: bool = False,
        num_drugs: int = 1000,
        fp_dim: int = 2048,
    ):
        super().__init__()

        # Placeholder fingerprint table for the standalone version.
        fingerprints = torch.zeros((num_drugs, fp_dim), dtype=torch.float32)

        self.embedding = nn.Embedding.from_pretrained(
            fingerprints,
            padding_idx=padding_idx,
            freeze=freeze,
        )

        self.fc = nn.Linear(fp_dim, d_out)

        # Unknown names fall back to ReLU, matching the original behavior.
        activations = {
            "leaky_relu": nn.LeakyReLU,
            "relu": nn.ReLU,
            "gelu": nn.GELU,
        }
        self.activation = activations.get(activation, nn.ReLU)()

        self.proj = nn.Linear(d_out, d_out)

        self.use_norm = use_norm
        if self.use_norm:
            self.norm = nn.LayerNorm(d_out)

    def forward(self, x: Tensor) -> Tensor:
        """Map integer drug ids of shape (...,) to embeddings (..., d_out)."""
        hidden = self.embedding(x)
        hidden = self.activation(self.fc(hidden))
        hidden = self.proj(hidden)
        return self.norm(hidden) if self.use_norm else hidden
|
|
|
|
|
|
|
|
class ContinuousValueEncoder(nn.Module):
    """Encode scalar continuous values into d_model-dimensional embeddings."""

    def __init__(
        self,
        d_model: int,
        dropout: float = 0.1,
        max_value: int = 512,
        activation: str = "relu",
        use_norm: bool = False,
    ):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)
        self.linear1 = nn.Linear(1, d_model)

        # Unknown names fall back to ReLU, matching the original behavior.
        activations = {
            "relu": nn.ReLU,
            "gelu": nn.GELU,
            "leaky_relu": nn.LeakyReLU,
        }
        self.activation = activations.get(activation, nn.ReLU)()

        self.linear2 = nn.Linear(d_model, d_model)

        self.use_norm = use_norm
        if self.use_norm:
            self.norm = nn.LayerNorm(d_model)

        # Values above this threshold are clipped before encoding.
        self.max_value = max_value

    def forward(self, x: Tensor) -> Tensor:
        """Encode values of shape (...,) to embeddings of shape (..., d_model)."""
        # Treat each scalar as a length-1 feature vector and clip outliers.
        hidden = torch.clamp(x.unsqueeze(-1), max=self.max_value)
        hidden = self.linear2(self.activation(self.linear1(hidden)))
        if self.use_norm:
            hidden = self.norm(hidden)
        return self.dropout(hidden)
|
|
|
|
|
|
|
|
class ExprDecoder(nn.Module):
    """MLP head that decodes embeddings into expression value predictions."""

    # Unknown names fall back to LeakyReLU, matching the original behavior.
    _ACTIVATIONS = {
        "leaky_relu": nn.LeakyReLU,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
    }

    def __init__(
        self,
        d_model: int,
        n_outputs: int = 1,
        n_layers: int = 2,
        activation: str = "leaky_relu",
    ):
        super().__init__()

        self.activation = self._ACTIVATIONS.get(activation, nn.LeakyReLU)()

        self.linear_layers = nn.ModuleList(
            nn.Linear(d_model, d_model) for _ in range(n_layers)
        )
        self.out_proj = nn.Linear(d_model, n_outputs)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Return {"pred": ...}; a trailing singleton output dim is squeezed."""
        for linear in self.linear_layers:
            x = self.activation(linear(x))
        pred = self.out_proj(x)
        if pred.shape[-1] == 1:
            pred = pred.squeeze(-1)
        return {"pred": pred}
|
|
|
|
|
|
|
|
class MVCDecoder(nn.Module):
    """Masked value prediction decoder.

    Predicts an expression value for each gene as the inner product of a
    (transformed) gene query vector with the cell embedding.

    Args:
        d_model: Embedding dimension of cell and gene embeddings.
        arch_style: Only "inner product" is supported.
        query_activation: "sigmoid", "relu", or "tanh" (unknown names fall
            back to sigmoid).
        scaled_dot_product: Scale predictions by 1/sqrt(d_model), as in
            scaled dot-product attention.

    Raises:
        ValueError: If ``arch_style`` is unknown.
    """

    def __init__(
        self,
        d_model: int,
        arch_style: str = "inner product",
        query_activation: str = "sigmoid",
        scaled_dot_product: bool = False,
    ) -> None:
        super().__init__()

        self.scaled_dot_product = scaled_dot_product

        if arch_style == "inner product":
            self.gene2query = nn.Linear(d_model, d_model)

            if query_activation == "sigmoid":
                self.query_activation = nn.Sigmoid()
            elif query_activation == "relu":
                self.query_activation = nn.ReLU()
            elif query_activation == "tanh":
                self.query_activation = nn.Tanh()
            else:
                # Unknown names fall back to sigmoid.
                self.query_activation = nn.Sigmoid()

            self.W = nn.Linear(d_model, d_model, bias=False)
        else:
            raise ValueError(f"Unknown arch_style: {arch_style}")

        self.arch_style = arch_style

    def forward(
        self,
        cell_emb: Tensor,
        gene_embs: Tensor,
    ) -> Dict[str, Tensor]:
        """Predict one value per gene from the cell embedding.

        Args:
            cell_emb: Cell embeddings; the ``unsqueeze(2)``/``bmm`` below
                assumes shape ``(batch, d_model)``.
            gene_embs: Gene embeddings of shape ``(batch, seq_len, d_model)``.

        Returns:
            {"pred": Tensor of shape ``(batch, seq_len)``}.
        """
        if self.arch_style == "inner product":
            query_vecs = self.query_activation(self.gene2query(gene_embs))

            # (batch, seq, d) @ (batch, d, 1) -> (batch, seq, 1) -> (batch, seq)
            pred_value = torch.bmm(self.W(query_vecs), cell_emb.unsqueeze(2)).squeeze(2)

            if self.scaled_dot_product:
                # math.sqrt avoids allocating a fresh CPU tensor on every
                # forward (the old torch.sqrt(torch.tensor(...)) lived on
                # the CPU regardless of the input's device).
                pred_value = pred_value / math.sqrt(query_vecs.shape[-1])

            return {"pred": pred_value}
        else:
            raise ValueError(f"Unknown arch_style: {self.arch_style}")
|
|
|