Safetensors
soumyatghosh's picture
Upload folder using huggingface_hub
4527b5f verified
"""
Module: model.py
This module defines the `TeddyGModel`, a transformer-based architecture designed for single-cell biology tasks.
The model is built on top of Hugging Face's `PreTrainedModel` and includes custom configurations, embeddings,
and classification heads to handle gene expression data and biological annotations.
Main Features:
- **TeddyGConfig**: A configuration class for specifying model hyperparameters such as the number of tokens,
embedding dimensions, number of layers, and loss weights.
- **TeddyGModel**: The main transformer-based model that supports:
- Gene token embeddings and position embeddings.
- Biological annotation embeddings (e.g., disease, tissue, cell type, sex).
- Masked language modeling and annotation classification losses.
- Gradient checkpointing for memory efficiency during training.
- Customizable classification heads for downstream tasks.
- **TeddyGModelAnalysis**: A subclass of `TeddyGModel` with additional functionality for analysis tasks.
Dependencies:
- PyTorch: For defining and training the model.
- Transformers: For leveraging Hugging Face's `PreTrainedModel` and `PretrainedConfig`.
- Torch.nn: For building neural network layers and components.
Usage:
1. Define a `TeddyGConfig` object with the desired hyperparameters.
2. Initialize a `TeddyGModel` using the configuration.
3. Use the model for tasks such as masked language modeling, annotation classification, or embedding extraction.
Example:
```python
from teddy.models.teddy_g.model import TeddyGConfig, TeddyGModel
# Define the configuration
config = TeddyGConfig(...)
# Initialize the model
model = TeddyGModel(config)
```
"""
from typing import Mapping, Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PretrainedConfig, PreTrainedModel
from teddy.models.classification_heads import (
ClassificationHead,
ClassificationHeadAnalysis,
ClsDecoder,
)
class TeddyGConfig(PretrainedConfig):
    """Hyperparameter container for `TeddyGModel`.

    Every constructor argument is stored as an attribute of the same name,
    except `pad_token_id`, which is stored as both `pad_value` (the name this
    class has always used) and `pad_token_id` (the name `TeddyGModel.forward`
    reads when it has to derive a padding mask from the token ids).

    Args:
        annotation_loss_weight: Weight of the annotation loss in the combined
            masking loss; the weighted sum is only used by the model when both
            weights are set.
        modeling_loss_weight: Weight of the gene-token modeling loss.
        ntoken: Number of rows of the token embedding table and width of the
            decoder output.
        max_position_embeddings: Number of rows of the position embedding table.
        nlayers: Number of transformer encoder layers.
        nheads: Number of attention heads per encoder layer.
        d_model: Embedding / hidden size.
        d_hid: Feed-forward size passed to each encoder layer.
        layer_activation: Activation passed to each encoder layer.
        n_layers_cls: Layer count handed to the classification head.
        n_cls: Number of classes handed to the classification head
            (no head is built when 0).
        dropout: Dropout probability passed to each encoder layer.
        initializer_range: Standard deviation used for normal weight init.
        pad_token_id: Token id treated as padding.
        pre_norm: Passed as `norm_first` to each encoder layer.
        cls_loss: Enable the classification-head loss in the model's forward.
        masking_loss: Enable the masked-token loss in the model's forward.
        decoding_loss: Enable the first-token decoding loss in the model's forward.
        gradient_checkpointing: Enable per-layer activation checkpointing
            while training.
        **kwargs: Forwarded to `PretrainedConfig.__init__`.
    """

    # Declared at class level as well, following the Transformers convention
    # of identifying a config type by a class attribute.
    model_type = "teddy_g"

    def __init__(
        self,
        annotation_loss_weight: Optional[float] = None,
        modeling_loss_weight: Optional[float] = None,
        ntoken: int = 25472,
        max_position_embeddings: int = 1500,
        nlayers: int = 12,
        nheads: int = 16,
        d_model: int = 512,
        d_hid: int = 1024,
        layer_activation: str = "relu",
        n_layers_cls: int = 0,
        n_cls: int = 0,
        dropout: float = 0.0,
        initializer_range: float = 0.02,
        pad_token_id: int = -100,
        pre_norm: bool = False,
        cls_loss: bool = False,
        masking_loss: bool = False,
        decoding_loss: bool = False,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.annotation_loss_weight = annotation_loss_weight
        self.modeling_loss_weight = modeling_loss_weight
        self.ntoken = ntoken
        self.d_model = d_model
        self.nheads = nheads
        self.d_hid = d_hid
        self.nlayers = nlayers
        self.layer_activation = layer_activation
        self.n_layers_cls = n_layers_cls
        self.n_cls = n_cls
        self.dropout = dropout
        self.initializer_range = initializer_range
        self.pad_value = pad_token_id
        # `pad_token_id` is a named parameter here, so it is not part of the
        # kwargs handed to the base class above. Store it explicitly so that
        # `config.pad_token_id` (read by TeddyGModel.forward) matches `pad_value`.
        self.pad_token_id = pad_token_id
        self.pre_norm = pre_norm
        self.cls_loss = cls_loss
        self.decoding_loss = decoding_loss
        self.masking_loss = masking_loss
        self.max_position_embeddings = max_position_embeddings
        self.gradient_checkpointing = gradient_checkpointing
        self.architectures = ["TeddyGModel"]
        # Re-assert on the instance: a value passed through **kwargs (e.g. from
        # a previously saved config) would otherwise shadow the class attribute.
        self.model_type = "teddy_g"
class TeddyGModel(PreTrainedModel):
    """Transformer encoder over gene tokens, optionally prefixed by annotation tokens.

    The model owns a token embedding table (shared by gene and annotation
    tokens), a position embedding table applied to the gene tokens only, a
    stack of `TransformerEncoderLayer`s, a linear decoder with a separate bias
    over the token vocabulary, and (when `config.n_cls > 0`) a classification
    head stored as `self.cls_decoder`.

    Which embeddings `forward` returns is controlled by the instance flags
    `return_cell_embs_first_token`, `return_cell_embs_all_tokens_mean` and
    `return_all_embs`; which losses it computes is controlled by
    `masking_loss`, `cls_loss` and `decoding_loss` (copied from the config).
    """

    def __init__(
        self,
        config: TeddyGConfig,
    ):
        super().__init__(config)
        self.config = config
        self.embeddings = nn.Embedding(config.ntoken, config.d_model)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        encoder_layers = TransformerEncoderLayer(
            config.d_model,
            config.nheads,
            config.d_hid,
            config.dropout,
            batch_first=True,
            norm_first=config.pre_norm,
            activation=config.layer_activation,
        )
        self.encoder = TransformerEncoder(encoder_layers, config.nlayers)
        # Decoder weight and bias are kept as separate parameters.
        self.decoder_head = nn.Linear(config.d_model, config.ntoken, bias=False)
        self.decoder_bias = nn.Parameter(torch.zeros(config.ntoken))
        if config.n_cls > 0:
            self.add_classification_head(config.d_model, config.n_cls, config.n_layers_cls)
        self.gradient_checkpointing = config.gradient_checkpointing
        self.cls_loss = config.cls_loss
        self.masking_loss = config.masking_loss
        self.decoding_loss = config.decoding_loss
        self.return_all_embs = False
        self.return_cell_embs_first_token = True  # return first token slice
        self.return_cell_embs_all_tokens_mean = False
        self.init_weights()

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear and embedding weights are drawn from N(0, initializer_range),
        linear biases and the embedding padding row are zeroed, and LayerNorm
        is reset to weight 1 / bias 0. Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def add_cls_decoder(self, d_model, n_cls, nlayers):
        """Install a freshly initialised `ClsDecoder` as `self.cls_decoder`.

        Also records `n_cls` and `nlayers` on the config.
        """
        self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers)
        for m in self.cls_decoder.modules():
            self._init_weights(m)
        self.config.n_cls = n_cls
        self.config.n_layers_cls = nlayers

    def add_classification_head(self, d_model, n_cls, nlayers):
        """Install a freshly initialised `ClassificationHead` as `self.cls_decoder`.

        `d_model` is accepted for signature parity with `add_cls_decoder` but
        is not used: the head is built from `self.config`. Also records
        `n_cls` and `nlayers` on the config.
        """
        self.cls_decoder = ClassificationHead(self.config, n_cls, nlayers)
        for m in self.cls_decoder.modules():
            self._init_weights(m)
        self.config.n_cls = n_cls
        self.config.n_layers_cls = nlayers

    def extend_token_embeddings(self):
        """Grow the token vocabulary by one entry.

        Increments `config.ntoken` and replaces the token embedding table, the
        decoder weight and the decoder bias by copies that have one extra
        trailing row/element. Existing values are preserved; the new embedding
        and decoder rows are drawn from N(0, initializer_range) and the new
        bias element is zero. The replacements live on the same device and use
        the same dtype as the parameters they replace.
        """
        old_embedding_weight = self.embeddings.weight
        old_decoder_weight = self.decoder_head.weight
        old_bias = self.decoder_bias
        std = self.config.initializer_range

        self.config.ntoken += 1
        ntoken = self.config.ntoken

        # Move whole modules with Module.to(): assigning `weight.to(device)`
        # back to `.weight` hands nn.Module a plain tensor instead of a
        # Parameter whenever the device actually changes.
        new_gene_embeddings = nn.Embedding(ntoken, self.config.d_model)
        new_gene_embeddings.weight.data.normal_(mean=0.0, std=std)
        new_gene_embeddings.to(device=old_embedding_weight.device, dtype=old_embedding_weight.dtype)

        new_decoder_head = nn.Linear(self.config.d_model, ntoken, bias=False)
        new_decoder_head.weight.data.normal_(mean=0.0, std=std)
        new_decoder_head.to(device=old_decoder_weight.device, dtype=old_decoder_weight.dtype)

        new_bias = nn.Parameter(torch.zeros(ntoken, device=old_bias.device, dtype=old_bias.dtype))

        with torch.no_grad():
            new_gene_embeddings.weight[:-1, :] = old_embedding_weight
            new_decoder_head.weight[:-1, :] = old_decoder_weight
            new_bias[:-1] = old_bias

        self.embeddings = new_gene_embeddings
        self.decoder_head = new_decoder_head
        self.decoder_bias = new_bias

    def run_layer(self, index):
        """Return a positional-argument wrapper around encoder layer `index`.

        The wrapper takes `(hidden_states, key_padding_mask)`, which is the
        calling shape activation checkpointing uses in `forward`.
        """

        def custom_forward(*inputs):
            return self.encoder.layers[index](
                src=inputs[0],  # hidden_states
                src_key_padding_mask=inputs[1],  # attention_mask
            )

        return custom_forward

    @staticmethod
    def _masked_token_stats(logits, labels):
        """Return `(nlls, predictions, kept_labels)` for positions whose label is not -100.

        `labels` is flattened completely and `logits` over all but its last
        dimension, so both `(batch, seq[, vocab])` and `(batch[, vocab])`
        inputs are accepted.
        """
        keep = (labels != -100).flatten(0, -1)
        kept_labels = labels.flatten(0, -1)[keep].long()
        kept_logits = logits.flatten(0, -2)[keep]
        nlls = -F.log_softmax(kept_logits, dim=1)
        nlls = torch.gather(input=nlls, dim=-1, index=kept_labels.unsqueeze(-1)).squeeze(-1)
        return nlls, torch.argmax(kept_logits, dim=-1), kept_labels

    def forward(
        self,
        gene_ids: Tensor,
        labels: Optional[Tensor] = None,
        annotations: Optional[Tensor] = None,
        annotation_labels: Optional[Tensor] = None,
        annotation_attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        return_outputs: Optional[bool] = False,
        **kwargs,
    ) -> Mapping[str, Tensor]:
        """Encode a batch and compute whichever outputs/losses are enabled.

        Args:
            gene_ids: token ids, shape [batch_size, seq_len].
            labels: targets. Used as per-token targets (-100 = ignore) by the
                masking loss and as per-sample class targets by the
                classification loss.
            annotations: annotation token ids, shape [batch_size, n_annotations];
                their embeddings are prepended to the gene embeddings.
            annotation_labels: targets for the annotation positions
                (-100 = ignore). The per-annotation outputs produced with
                `return_outputs=True` read columns 0..3 as disease, tissue,
                cell_type, sex.
                NOTE(review): the previous docstring listed the order as
                disease, cell type, tissue type, sex - confirm against the
                data pipeline.
            annotation_attention_mask: mask for the annotation positions,
                truthy = attend. When omitted, all annotation positions are
                attended.
            position_ids: position indices for the gene tokens; defaults to
                0..seq_len-1.
            attention_mask: mask for gene_ids, truthy = attend, shape
                [batch_size, seq_len]. When omitted, positions equal to the
                configured pad id are masked out.
            return_outputs: with the masking loss active, additionally return
                per-position negative log-likelihoods, predictions and labels.
            **kwargs: ignored.

        Returns:
            dict of output Tensors. Depending on flags: "cell_emb",
            "cell_emb_mean", "all_embs", "annotation_loss", "modeling_loss",
            "loss", "cls_output" and the `return_outputs` diagnostics.

        Raises:
            ValueError: if the masking loss is active and `annotations` are
                given without `annotation_labels`.
        """
        gene_ids = gene_ids.long()
        embeddings = self.embeddings(gene_ids)
        if position_ids is None:
            position_ids = torch.arange(0, gene_ids.shape[1], device=self.position_embeddings.weight.device)
        embeddings = embeddings + self.position_embeddings(position_ids)

        # Number of annotation tokens sitting in front of the gene tokens.
        n_annotations = 0
        if annotations is not None:
            annotations = annotations.long()
            embeddings = torch.cat([self.embeddings(annotations), embeddings], dim=1)
            n_annotations = annotations.shape[1]

        # attention masks
        if attention_mask is not None:
            # pytorch TransformerEncoder uses opposite convention from huggingface
            attention_mask = ~attention_mask.bool()
        else:
            pad_id = getattr(self.config, "pad_token_id", None)
            if pad_id is None:
                # Configs built by TeddyGConfig always carry `pad_value`.
                pad_id = getattr(self.config, "pad_value", None)
            if pad_id is None:
                attention_mask = torch.zeros_like(gene_ids, dtype=torch.bool)
            else:
                attention_mask = gene_ids == pad_id
        if annotation_attention_mask is not None:
            annotation_attention_mask = ~annotation_attention_mask.bool()
        else:
            # Boolean (batch, n_annotations) placeholder: keeps the concatenated
            # mask boolean and aligned with the annotation-prefixed sequence.
            annotation_attention_mask = torch.zeros(
                (gene_ids.shape[0], n_annotations), dtype=torch.bool, device=gene_ids.device
            )
        attention_mask = torch.cat([annotation_attention_mask, attention_mask], dim=1)

        transformer_output = embeddings
        if self.gradient_checkpointing and self.training:
            # Function-scope import of the name only, so `torch` is not
            # rebound as a local of this function.
            from torch.utils.checkpoint import checkpoint

            for index in range(len(self.encoder.layers)):
                transformer_output = checkpoint(
                    self.run_layer(index), transformer_output, attention_mask, use_reentrant=True
                )
        else:
            for layer in self.encoder.layers:
                transformer_output = layer(src=transformer_output, src_key_padding_mask=attention_mask)

        output = {}
        cell_emb = transformer_output[:, 0, :]
        if self.return_cell_embs_first_token:
            output["cell_emb"] = cell_emb  # (batch, embsize)
        if self.return_cell_embs_all_tokens_mean:
            output["cell_emb_mean"] = transformer_output.mean(dim=1)  # (batch, embsize)
        if self.return_all_embs:
            output["all_embs"] = transformer_output

        if self.masking_loss and labels is not None:
            labels = labels.long()
            logits = self.decoder_head(transformer_output) + self.decoder_bias
            if annotation_labels is not None:
                all_labels = torch.cat([annotation_labels.long(), labels], dim=1)
            else:
                if annotations is not None:
                    raise ValueError("Got annotations and masking loss but not annotation labels were provided")
                all_labels = labels

            if return_outputs:
                nlls, predictions, masked_labels = self._masked_token_stats(logits[:, n_annotations:], labels)
                output["modeling_nlls"] = nlls  # (seq,)
                output["modeling_predictions"] = predictions  # (seq,)
                output["masked_labels"] = masked_labels

                if annotation_labels is not None:
                    annotation_logits = logits[:, :n_annotations]
                    nlls, predictions, masked_labels = self._masked_token_stats(annotation_logits, annotation_labels)
                    output["annotation_nlls"] = nlls  # (seq,)
                    output["annotation_predictions"] = predictions  # (seq,)
                    output["masked_annotation_labels"] = masked_labels

                    for n, u_annot in enumerate(["disease", "tissue", "cell_type", "sex"]):
                        nlls, predictions, masked_labels = self._masked_token_stats(
                            annotation_logits[:, n], annotation_labels[:, n]
                        )
                        output[f"{u_annot}_nlls"] = nlls  # (seq,)
                        output[f"{u_annot}_predictions"] = predictions  # (seq,)
                        output[f"masked_{u_annot}_labels"] = masked_labels

            cross_entropies = F.cross_entropy(
                logits.reshape(-1, self.config.ntoken), all_labels.reshape(-1), reduction="none"
            )  # (seq len,)
            cross_entropies = cross_entropies.view(logits.shape[:-1])
            if annotation_labels is not None:
                annotation_ce = cross_entropies[:, :n_annotations]
                output["annotation_loss"] = annotation_ce[annotation_labels != -100].mean()
            modeling_ce = cross_entropies[:, n_annotations:]
            output["modeling_loss"] = modeling_ce[labels != -100].mean()

            use_weighted_sum = (
                annotation_labels is not None
                and self.config.annotation_loss_weight is not None
                and self.config.modeling_loss_weight is not None
            )
            if use_weighted_sum:
                output["loss"] = (
                    self.config.annotation_loss_weight * output["annotation_loss"]
                    + self.config.modeling_loss_weight * output["modeling_loss"]
                )
            else:
                output["loss"] = cross_entropies[all_labels != -100].mean()

        if self.config.n_cls > 1:
            output["cls_output"] = self.cls_decoder(cell_emb)  # (batch, n_cls)
            if self.cls_loss and labels is not None:
                output["loss"] = F.cross_entropy(output["cls_output"]["output"], labels.long())

        if self.decoding_loss:
            # Use the local first-token slice: output["cell_emb"] is absent when
            # return_cell_embs_first_token is disabled.
            logits = self.decoder_head(cell_emb) + self.decoder_bias
            output["cls_output"] = {
                "output": F.log_softmax(logits, dim=-1)
            }  # NOTE: only implemented for disease classification
            if annotation_labels is not None:
                output["loss"] = F.cross_entropy(logits, annotation_labels[:, 0].long())
        return output
class TeddyGModelAnalysis(TeddyGModel):
    """`TeddyGModel` whose classification head is a `ClassificationHeadAnalysis`.

    Construction is delegated to `TeddyGModel`; afterwards, when the config
    asks for more than one class, `self.cls_decoder` is replaced by a
    `ClassificationHeadAnalysis` built from the same config values.
    """

    def __init__(self, config):
        super().__init__(config)
        n_classes = config.n_cls
        if n_classes <= 1:
            # Nothing to swap: keep whatever the parent constructor set up.
            return
        self.cls_decoder = ClassificationHeadAnalysis(config, n_classes, config.n_layers_cls)