|
|
""" |
|
|
Module: model.py |
|
|
|
|
|
This module defines the `TeddyGModel`, a transformer-based architecture designed for single-cell biology tasks. |
|
|
The model is built on top of Hugging Face's `PreTrainedModel` and includes custom configurations, embeddings, |
|
|
and classification heads to handle gene expression data and biological annotations. |
|
|
|
|
|
Main Features: |
|
|
- **TeddyGConfig**: A configuration class for specifying model hyperparameters such as the number of tokens, |
|
|
embedding dimensions, number of layers, and loss weights. |
|
|
- **TeddyGModel**: The main transformer-based model that supports: |
|
|
- Gene token embeddings and position embeddings. |
|
|
- Biological annotation embeddings (e.g., disease, tissue, cell type, sex). |
|
|
- Masked language modeling and annotation classification losses. |
|
|
- Gradient checkpointing for memory efficiency during training. |
|
|
- Customizable classification heads for downstream tasks. |
|
|
- **TeddyGModelAnalysis**: A subclass of `TeddyGModel` with additional functionality for analysis tasks. |
|
|
|
|
|
Dependencies: |
|
|
- PyTorch: For defining and training the model. |
|
|
- Transformers: For leveraging Hugging Face's `PreTrainedModel` and `PretrainedConfig`. |
|
|
- Torch.nn: For building neural network layers and components. |
|
|
|
|
|
Usage: |
|
|
1. Define a `TeddyGConfig` object with the desired hyperparameters. |
|
|
2. Initialize a `TeddyGModel` using the configuration. |
|
|
3. Use the model for tasks such as masked language modeling, annotation classification, or embedding extraction. |
|
|
|
|
|
Example: |
|
|
```python |
|
|
from teddy.models.teddy_g.model import TeddyGConfig, TeddyGModel |
|
|
|
|
|
# Define the configuration |
|
|
config = TeddyGConfig(...) |
|
|
|
|
|
# Initialize the model |
|
|
model = TeddyGModel(config) |
|
|
""" |
|
|
|
|
|
from typing import Mapping, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor, nn |
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
from teddy.models.classification_heads import ( |
|
|
ClassificationHead, |
|
|
ClassificationHeadAnalysis, |
|
|
ClsDecoder, |
|
|
) |
|
|
|
|
|
|
|
|
class TeddyGConfig(PretrainedConfig):
    """Configuration for :class:`TeddyGModel`.

    Stores the architecture hyperparameters (vocabulary size, depth, width),
    the loss-selection flags, and the optional loss-mixing weights.

    Args:
        annotation_loss_weight: Weight of the annotation cross-entropy in the
            combined loss; both weights must be set for weighting to apply.
        modeling_loss_weight: Weight of the masked-modeling cross-entropy in
            the combined loss.
        ntoken: Vocabulary size (number of gene/annotation tokens).
        max_position_embeddings: Maximum supported sequence length.
        nlayers: Number of transformer encoder layers.
        nheads: Number of attention heads per layer.
        d_model: Embedding / hidden dimension.
        d_hid: Feed-forward inner dimension.
        layer_activation: Activation used inside the encoder layers.
        n_layers_cls: Number of layers in the classification head.
        n_cls: Number of classes for the classification head (0 disables it).
        dropout: Dropout probability in the encoder layers.
        initializer_range: Std of the normal init for Linear/Embedding weights.
        pad_token_id: Token id treated as padding when no attention mask is
            supplied to ``forward``.
        pre_norm: If True, use pre-LayerNorm encoder layers.
        cls_loss: If True, use the classification-head cross-entropy as loss.
        masking_loss: If True, compute the masked-modeling (and annotation)
            losses.
        decoding_loss: If True, decode the cell embedding and use the first
            annotation column as the target.
        gradient_checkpointing: If True, checkpoint encoder layers in training.
    """

    model_type = "teddy_g"

    def __init__(
        self,
        annotation_loss_weight: Optional[float] = None,
        modeling_loss_weight: Optional[float] = None,
        ntoken: int = 25472,
        max_position_embeddings: int = 1500,
        nlayers: int = 12,
        nheads: int = 16,
        d_model: int = 512,
        d_hid: int = 1024,
        layer_activation="relu",
        n_layers_cls: int = 0,
        n_cls: int = 0,
        dropout: float = 0.0,
        initializer_range=0.02,
        pad_token_id: int = -100,
        pre_norm: bool = False,
        cls_loss=False,
        masking_loss=False,
        decoding_loss=False,
        gradient_checkpointing=False,
        **kwargs,
    ):
        # Fix: forward pad_token_id to the base class. Previously it was only
        # stored as `pad_value`, so `config.pad_token_id` — which
        # `TeddyGModel.forward` reads to build the default padding mask —
        # silently stayed at the PretrainedConfig default (None).
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.annotation_loss_weight = annotation_loss_weight
        self.modeling_loss_weight = modeling_loss_weight
        self.ntoken = ntoken
        self.d_model = d_model
        self.nheads = nheads
        self.d_hid = d_hid
        self.nlayers = nlayers
        self.layer_activation = layer_activation
        self.n_layers_cls = n_layers_cls
        self.n_cls = n_cls
        self.dropout = dropout
        self.initializer_range = initializer_range
        # Kept for backward compatibility with code reading `pad_value`.
        self.pad_value = pad_token_id
        self.pre_norm = pre_norm
        self.cls_loss = cls_loss
        self.decoding_loss = decoding_loss
        self.masking_loss = masking_loss
        self.max_position_embeddings = max_position_embeddings
        self.gradient_checkpointing = gradient_checkpointing
        self.architectures = ["TeddyGModel"]
        self.model_type = "teddy_g"
|
|
|
|
|
|
|
|
class TeddyGModel(PreTrainedModel):
    """Transformer encoder for single-cell gene-expression sequences.

    Token ids are embedded with a shared embedding table; optional
    biological-annotation token ids are embedded with the same table and
    prepended to the sequence. The sequence is processed by a stack of
    ``TransformerEncoderLayer``s, and the model can emit cell embeddings,
    masked-modeling / annotation losses, a classification-head output, or a
    decoding loss, depending on the configuration flags.
    """

    config_class = TeddyGConfig

    def __init__(
        self,
        config: TeddyGConfig,
    ):
        super().__init__(config)
        self.config = config

        # Shared token embedding table (genes and annotation tokens).
        self.embeddings = nn.Embedding(config.ntoken, config.d_model)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        encoder_layers = TransformerEncoderLayer(
            config.d_model,
            config.nheads,
            config.d_hid,
            config.dropout,
            batch_first=True,
            norm_first=config.pre_norm,
            activation=config.layer_activation,
        )
        self.encoder = TransformerEncoder(encoder_layers, config.nlayers)
        # Decoder head projects back to the vocabulary; bias kept as a
        # separate Parameter so it can be resized in extend_token_embeddings.
        self.decoder_head = nn.Linear(config.d_model, config.ntoken, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(config.ntoken))

        if config.n_cls > 0:
            self.add_classification_head(config.d_model, config.n_cls, config.n_layers_cls)

        self.gradient_checkpointing = config.gradient_checkpointing
        self.cls_loss = config.cls_loss
        self.masking_loss = config.masking_loss
        self.decoding_loss = config.decoding_loss
        # Output-selection switches (mutable at runtime by callers).
        self.return_all_embs = False
        self.return_cell_embs_first_token = True
        self.return_cell_embs_all_tokens_mean = False
        self.init_weights()

    def _init_weights(self, module):
        """Hugging Face weight-init hook: normal init for Linear/Embedding,
        unit/zero init for LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def add_cls_decoder(self, d_model, n_cls, nlayers):
        """Attach a freshly-initialized ``ClsDecoder`` head and record its
        shape in the config."""
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers)
        for m in self.cls_decoder.modules():
            self._init_weights(m)
        self.config.n_cls = n_cls
        self.config.n_layers_cls = nlayers

    def add_classification_head(self, d_model, n_cls, nlayers):
        """Attach a freshly-initialized ``ClassificationHead`` and record its
        shape in the config."""
        self.cls_decoder = ClassificationHead(self.config, n_cls, nlayers)
        for m in self.cls_decoder.modules():
            self._init_weights(m)
        self.config.n_cls = n_cls
        self.config.n_layers_cls = nlayers

    def extend_token_embeddings(self):
        """Grow the vocabulary by one token, preserving existing weights.

        Adds one row to the embedding table, the decoder head, and the
        decoder bias; the new row is randomly initialized (bias row zero).
        """
        self.config.ntoken += 1
        device = self.embeddings.weight.device

        new_gene_embeddings = nn.Embedding(self.config.ntoken, self.config.d_model)
        new_gene_embeddings.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        # Fix: move the whole module. `param.to(device)` returns a plain
        # Tensor when the device differs, and assigning that back onto
        # `module.weight` breaks the Parameter registration.
        new_gene_embeddings = new_gene_embeddings.to(device)

        new_decoder_head = nn.Linear(self.config.d_model, self.config.ntoken, bias=False)
        new_decoder_head.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        new_decoder_head = new_decoder_head.to(device)

        # Fix: create the bias on the model's device (it previously stayed on
        # CPU, causing a device mismatch after extending a GPU model).
        new_bias = nn.Parameter(torch.zeros(self.config.ntoken, device=device))

        with torch.no_grad():
            new_gene_embeddings.weight[:-1, :] = self.embeddings.weight
            self.embeddings = new_gene_embeddings

            new_decoder_head.weight[:-1, :] = self.decoder_head.weight
            self.decoder_head = new_decoder_head

            new_bias[:-1] = self.decoder_bias
            self.decoder_bias = new_bias

    def run_layer(self, index):
        """Return a closure over encoder layer `index` in the positional form
        expected by ``torch.utils.checkpoint.checkpoint``."""

        def custom_forward(*inputs):
            return self.encoder.layers[index](
                src=inputs[0],
                src_key_padding_mask=inputs[1],
            )

        return custom_forward

    @staticmethod
    def _gather_nlls(flat_logits, flat_labels):
        """Per-position negative log-likelihood of `flat_labels` under
        `flat_logits` (no reduction)."""
        nlls = -F.log_softmax(flat_logits, dim=1)
        return torch.gather(input=nlls, dim=-1, index=flat_labels.unsqueeze(-1)).squeeze(-1)

    def _add_token_outputs(self, output, modeling_logits, labels):
        """Write per-token NLLs, predictions, and targets for masked gene
        positions (label != -100) into `output`."""
        flat_positions = (labels != -100).flatten(0, -1)
        masked_labels = labels.flatten(0, -1)[flat_positions].long()

        flat_logits = modeling_logits.flatten(0, -2)[flat_positions]
        output["modeling_nlls"] = self._gather_nlls(flat_logits, masked_labels)
        output["modeling_predictions"] = torch.argmax(flat_logits, dim=-1)
        output["masked_labels"] = masked_labels

    def _add_annotation_outputs(self, output, annotation_logits, annotation_labels):
        """Write overall and per-category annotation NLLs, predictions, and
        targets into `output`."""
        flat_positions = (annotation_labels != -100).flatten(0, -1)
        masked_annotation_labels = annotation_labels.flatten(0, -1)[flat_positions].long()

        flat_annotation_logits = annotation_logits.flatten(0, -2)[flat_positions]
        output["annotation_nlls"] = self._gather_nlls(flat_annotation_logits, masked_annotation_labels)
        output["annotation_predictions"] = torch.argmax(flat_annotation_logits, dim=-1)
        output["masked_annotation_labels"] = masked_annotation_labels

        # NOTE(review): assumes the annotation columns are exactly
        # [disease, tissue, cell_type, sex] in this order — confirm against
        # the data pipeline.
        for n, u_annot in enumerate(["disease", "tissue", "cell_type", "sex"]):
            u_labels = annotation_labels[:, n]
            u_positions = u_labels != -100
            masked_u_labels = u_labels[u_positions].long()

            u_logits = annotation_logits[:, n][u_positions]
            output[f"{u_annot}_nlls"] = self._gather_nlls(u_logits, masked_u_labels)
            output[f"{u_annot}_predictions"] = torch.argmax(u_logits, dim=-1)
            output[f"masked_{u_annot}_labels"] = masked_u_labels

    def forward(
        self,
        gene_ids: Tensor,
        labels: Optional[Tensor] = None,
        annotations: Optional[Tensor] = None,
        annotation_labels: Optional[Tensor] = None,
        annotation_attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        return_outputs: Optional[bool] = False,
        **kwargs,
    ) -> Mapping[str, Tensor]:
        """
        Args:
            gene_ids: token ids, shape [batch_size, seq_len]
            labels: masked-modeling targets for gene positions (-100 = ignore)
            annotations: [disease, cell type, tissue type, sex] token ids,
                prepended to the gene sequence
            annotation_labels: [disease, cell type, tissue type, sex] targets
            attention_mask: mask for gene_ids, shape [batch_size, seq_len]
                (1 = attend, 0 = pad; inverted internally for PyTorch)
            annotation_attention_mask: mask for annotation positions
            position_ids: optional explicit position ids for gene tokens
            return_outputs: also emit per-position NLLs/predictions

        Returns:
            dict of output Tensors (cell embeddings, losses, and — when
            requested — per-position diagnostics).
        """
        gene_ids = gene_ids.long()

        embeddings = self.embeddings(gene_ids)
        if position_ids is None:
            position_ids = torch.arange(0, gene_ids.shape[1], device=self.position_embeddings.weight.device)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings  # broadcasts over the batch dim

        if annotations is not None:
            annotations = annotations.long()
            # Annotation tokens share the gene embedding table.
            annotation_embeddings = self.embeddings(annotations)
            embeddings = torch.cat([annotation_embeddings, embeddings], dim=1)
            n_annotations = annotations.shape[1]
        else:
            # Fix: track the annotation width explicitly. The previous empty
            # 1-D placeholder tensor crashed on `.shape[1]` downstream.
            n_annotations = 0

        # Build the key-padding mask (True == position is ignored).
        if attention_mask is not None:
            attention_mask = ~attention_mask.bool()
        else:
            # Fix: fall back to `pad_value` when `pad_token_id` was never
            # forwarded to the base config (legacy checkpoints).
            pad_id = self.config.pad_token_id
            if pad_id is None:
                pad_id = self.config.pad_value
            attention_mask = gene_ids == pad_id
        if annotation_attention_mask is not None:
            annotation_attention_mask = ~annotation_attention_mask.bool()
            attention_mask = torch.cat([annotation_attention_mask, attention_mask], dim=1)

        if self.gradient_checkpointing and self.training:
            transformer_output = embeddings
            for index in range(len(self.encoder.layers)):
                transformer_output = torch.utils.checkpoint.checkpoint(
                    self.run_layer(index), transformer_output, attention_mask, use_reentrant=True
                )
        else:
            transformer_output = embeddings
            for layer in self.encoder.layers:
                transformer_output = layer(src=transformer_output, src_key_padding_mask=attention_mask)

        output = {}
        # First token (first annotation, or first gene when no annotations)
        # is used as the cell embedding.
        cell_emb = transformer_output[:, 0, :]

        if self.return_cell_embs_first_token:
            output["cell_emb"] = cell_emb

        if self.return_cell_embs_all_tokens_mean:
            output["cell_emb_mean"] = transformer_output.mean(dim=1)

        if self.return_all_embs:
            output["all_embs"] = transformer_output

        if self.masking_loss and labels is not None:
            labels = labels.long()

            logits = self.decoder_head(transformer_output) + self.decoder_bias

            if annotation_labels is not None:
                all_labels = torch.cat([annotation_labels.long(), labels], dim=1)
            else:
                if n_annotations > 0:
                    raise ValueError("Got annotations and masking loss but no annotation labels were provided")
                # Fix: without annotations the combined target is just the
                # gene labels; previously `all_labels` was left undefined
                # here, raising a NameError at the loss computation.
                all_labels = labels

            if return_outputs:
                self._add_token_outputs(output, logits[:, n_annotations:], labels)
                # Fix: only compute annotation diagnostics when annotation
                # labels exist (previously crashed on `None != -100`).
                if annotation_labels is not None:
                    self._add_annotation_outputs(output, logits[:, :n_annotations], annotation_labels)

            cross_entropies = F.cross_entropy(
                logits.view(-1, self.config.ntoken), all_labels.view(-1), reduction="none"
            )
            cross_entropies = cross_entropies.view(logits.shape[:-1])
            annotation_ce = cross_entropies[:, :n_annotations]
            modeling_ce = cross_entropies[:, n_annotations:]
            if annotation_labels is not None:
                output["annotation_loss"] = annotation_ce[annotation_labels != -100].mean()
            output["modeling_loss"] = modeling_ce[labels != -100].mean()
            if (
                annotation_labels is not None
                and self.config.annotation_loss_weight is not None
                and self.config.modeling_loss_weight is not None
            ):
                output["loss"] = (
                    self.config.annotation_loss_weight * output["annotation_loss"]
                    + self.config.modeling_loss_weight * output["modeling_loss"]
                )
            else:
                output["loss"] = cross_entropies[all_labels != -100].mean()

        if self.config.n_cls > 1:
            output["cls_output"] = self.cls_decoder(cell_emb)
            # NOTE: when cls_loss is enabled, `labels` are class targets and
            # this overwrites any masking loss computed above.
            if self.cls_loss and labels is not None:
                output["loss"] = F.cross_entropy(output["cls_output"]["output"], labels.long())

        if self.decoding_loss:
            # Fix: use the local `cell_emb` (always available) instead of
            # output["cell_emb"], which is absent when
            # `return_cell_embs_first_token` is disabled; also removed the
            # duplicated `logits = logits =` assignment.
            logits = self.decoder_head(cell_emb) + self.decoder_bias
            output["cls_output"] = {"output": F.log_softmax(logits, dim=-1)}
            # The first annotation column (disease) is the decoding target —
            # presumably intentional; confirm with training setup.
            output["loss"] = F.cross_entropy(logits, annotation_labels[:, 0].long())

        return output
|
|
|
|
|
|
|
|
class TeddyGModelAnalysis(TeddyGModel):
    """`TeddyGModel` variant for analysis workflows.

    Identical to the parent model except that, when a classification head is
    configured (``n_cls > 1``), it is replaced by a
    ``ClassificationHeadAnalysis`` head.
    """

    def __init__(self, config):
        super().__init__(config)

        if config.n_cls <= 1:
            return
        # Swap in the analysis-capable classification head.
        self.cls_decoder = ClassificationHeadAnalysis(config, config.n_cls, config.n_layers_cls)