""" Module: model.py This module defines the `TeddyGModel`, a transformer-based architecture designed for single-cell biology tasks. The model is built on top of Hugging Face's `PreTrainedModel` and includes custom configurations, embeddings, and classification heads to handle gene expression data and biological annotations. Main Features: - **TeddyGConfig**: A configuration class for specifying model hyperparameters such as the number of tokens, embedding dimensions, number of layers, and loss weights. - **TeddyGModel**: The main transformer-based model that supports: - Gene token embeddings and position embeddings. - Biological annotation embeddings (e.g., disease, tissue, cell type, sex). - Masked language modeling and annotation classification losses. - Gradient checkpointing for memory efficiency during training. - Customizable classification heads for downstream tasks. - **TeddyGModelAnalysis**: A subclass of `TeddyGModel` with additional functionality for analysis tasks. Dependencies: - PyTorch: For defining and training the model. - Transformers: For leveraging Hugging Face's `PreTrainedModel` and `PretrainedConfig`. - Torch.nn: For building neural network layers and components. Usage: 1. Define a `TeddyGConfig` object with the desired hyperparameters. 2. Initialize a `TeddyGModel` using the configuration. 3. Use the model for tasks such as masked language modeling, annotation classification, or embedding extraction. Example: ```python from teddy.models.teddy_g.model import TeddyGConfig, TeddyGModel # Define the configuration config = TeddyGConfig(...) # Initialize the model model = TeddyGModel(config) """ from typing import Mapping, Optional import torch import torch.nn.functional as F from torch import Tensor, nn from torch.nn import TransformerEncoder, TransformerEncoderLayer from transformers import PretrainedConfig, PreTrainedModel from teddy.models.classification_heads import ( ClassificationHead, ClassificationHeadAnalysis, ClsDecoder, ) class TeddyGConfig(PretrainedConfig): def __init__( self, annotation_loss_weight: Optional[float] = None, modeling_loss_weight: Optional[float] = None, ntoken: int = 25472, max_position_embeddings: int = 1500, nlayers: int = 12, nheads: int = 16, d_model: int = 512, d_hid: int = 1024, layer_activation="relu", n_layers_cls: int = 0, n_cls: int = 0, dropout: float = 0.0, initializer_range=0.02, pad_token_id: int = -100, pre_norm: bool = False, cls_loss=False, masking_loss=False, decoding_loss=False, gradient_checkpointing=False, **kwargs, ): super().__init__(**kwargs) self.annotation_loss_weight = annotation_loss_weight self.modeling_loss_weight = modeling_loss_weight self.ntoken = ntoken self.d_model = d_model self.nheads = nheads self.d_hid = d_hid self.nlayers = nlayers self.layer_activation = layer_activation self.n_layers_cls = n_layers_cls self.n_cls = n_cls self.dropout = dropout self.initializer_range = initializer_range self.pad_value = pad_token_id self.pre_norm = pre_norm self.cls_loss = cls_loss self.decoding_loss = decoding_loss self.masking_loss = masking_loss self.max_position_embeddings = max_position_embeddings self.gradient_checkpointing = gradient_checkpointing self.architectures = ["TeddyGModel"] self.model_type = "teddy_g" class TeddyGModel(PreTrainedModel): def __init__( self, config: TeddyGConfig, ): super().__init__(config) self.config = config self.embeddings = nn.Embedding(config.ntoken, config.d_model) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) encoder_layers = TransformerEncoderLayer( config.d_model, config.nheads, config.d_hid, config.dropout, batch_first=True, norm_first=config.pre_norm, activation=config.layer_activation, ) self.encoder = TransformerEncoder(encoder_layers, config.nlayers) self.decoder_head = nn.Linear(config.d_model, config.ntoken, bias=False) self.decoder_bias = nn.Parameter(torch.zeros(config.ntoken)) if config.n_cls > 0: self.add_classification_head(config.d_model, config.n_cls, config.n_layers_cls) self.gradient_checkpointing = config.gradient_checkpointing self.cls_loss = config.cls_loss self.masking_loss = config.masking_loss self.decoding_loss = config.decoding_loss self.return_all_embs = False self.return_cell_embs_first_token = True # return first token slice self.return_cell_embs_all_tokens_mean = False self.init_weights() def _init_weights(self, module): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def add_cls_decoder(self, d_model, n_cls, nlayers): self.cls_decoder = ClsDecoder(d_model, n_cls, nlayers) for m in self.cls_decoder.modules(): self._init_weights(m) self.config.n_cls = n_cls self.config.n_layers_cls = nlayers def add_classification_head(self, d_model, n_cls, nlayers): self.cls_decoder = ClassificationHead(self.config, n_cls, nlayers) for m in self.cls_decoder.modules(): self._init_weights(m) self.config.n_cls = n_cls self.config.n_layers_cls = nlayers def extend_token_embeddings(self): self.config.ntoken += 1 device = self.embeddings.weight.device new_gene_embeddings = nn.Embedding(self.config.ntoken, self.config.d_model) new_gene_embeddings.weight.data.normal_(mean=0.0, std=self.config.initializer_range) new_gene_embeddings.weight = new_gene_embeddings.weight.to(device) new_decoder_head = nn.Linear(self.config.d_model, self.config.ntoken, bias=False) new_decoder_head.weight.data.normal_(mean=0.0, std=self.config.initializer_range) new_decoder_head.weight = new_decoder_head.weight.to(device) new_bias = nn.Parameter(torch.zeros(self.config.ntoken)) with torch.no_grad(): new_gene_embeddings.weight[:-1, :] = self.embeddings.weight self.embeddings = new_gene_embeddings new_decoder_head.weight[:-1, :] = self.decoder_head.weight self.decoder_head = new_decoder_head new_bias[:-1] = self.decoder_bias self.decoder_bias = new_bias def run_layer(self, index): def custom_forward(*inputs): return self.encoder.layers[index]( src=inputs[0], # hidden_states src_key_padding_mask=inputs[1], # attention_mask ) return custom_forward def forward( self, gene_ids: Tensor, labels: Optional[Tensor] = None, annotations: Optional[Tensor] = None, annotation_labels: Optional[Tensor] = None, annotation_attention_mask: Optional[Tensor] = None, position_ids: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, return_outputs: Optional[bool] = False, **kwargs, ) -> Mapping[str, Tensor]: """ Args: gene_ids: token ids, shape [batch_size, seq_len] annotations: [disease, cell type, tissue type, sex] annotation_labels: [disease, cell type, tissue type, sex] attention_mask: mask for gene_ids, shape [batch_size, seq_len] annotation_attention_mask: mask for annotation labels Returns: dict of output Tensors. """ gene_ids = gene_ids.long() embeddings = self.embeddings(gene_ids) if position_ids is None: position_ids = torch.arange(0, gene_ids.shape[1], device=self.position_embeddings.weight.device) position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings if annotations is not None: annotations = annotations.long() annotation_embeddings = self.embeddings(annotations) embeddings = torch.cat([annotation_embeddings, embeddings], dim=1) else: annotations = torch.empty(0, device=gene_ids.device).long() # attention masks if attention_mask is not None: attention_mask = attention_mask.bool() attention_mask = ~attention_mask # pytorch TransformerEncoder uses opposite convention from huggingface else: attention_mask = gene_ids == self.config.pad_token_id if annotation_attention_mask is not None: annotation_attention_mask = annotation_attention_mask.bool() annotation_attention_mask = ~annotation_attention_mask else: annotation_attention_mask = torch.empty(0, device=gene_ids.device) attention_mask = torch.cat([annotation_attention_mask, attention_mask], dim=1) if self.gradient_checkpointing and self.training: transformer_output = embeddings for index in range(len(self.encoder.layers)): transformer_output = torch.utils.checkpoint.checkpoint( self.run_layer(index), transformer_output, attention_mask, use_reentrant=True ) else: transformer_output = embeddings for layer in self.encoder.layers: transformer_output = layer(src=transformer_output, src_key_padding_mask=attention_mask) output = {} cell_emb = transformer_output[:, 0, :] if self.return_cell_embs_first_token: output["cell_emb"] = cell_emb # (batch, embsize) if self.return_cell_embs_all_tokens_mean: output["cell_emb_mean"] = transformer_output.mean(dim=1) # (batch, embsize) if self.return_all_embs: output["all_embs"] = transformer_output if self.masking_loss: if labels is not None: labels = labels.long() logits = self.decoder_head(transformer_output) + self.decoder_bias if annotation_labels is not None: all_labels = torch.cat([annotation_labels.long(), labels], dim=1) else: if annotations.shape[0] > 0: raise ValueError("Got annotations and masking loss but not annotation labels were provided") if return_outputs: modeling_logits = logits[:, annotations.shape[1] :] label_positions = labels != -100 flat_positions = label_positions.flatten(0, -1) # (total_len) flat_labels = labels.flatten(0, -1) masked_labels = flat_labels[flat_positions].long() flat_logits = modeling_logits.flatten(0, -2) flat_logits = flat_logits[flat_positions] nlls = -F.log_softmax(flat_logits, dim=1) nlls = torch.gather(input=nlls, dim=-1, index=masked_labels.unsqueeze(-1)).squeeze(-1) output["modeling_nlls"] = nlls # (seq,) output["modeling_predictions"] = torch.argmax(flat_logits, dim=-1) # (seq,) output["masked_labels"] = masked_labels annotation_logits = logits[:, : annotations.shape[1]] annotation_label_positions = annotation_labels != -100 flat_annotation_positions = annotation_label_positions.flatten(0, -1) # (total_len) flat_annotation_labels = annotation_labels.flatten(0, -1) masked_annotation_labels = flat_annotation_labels[flat_annotation_positions].long() flat_annotation_logits = annotation_logits.flatten(0, -2) flat_annotation_logits = flat_annotation_logits[flat_annotation_positions] annotation_nlls = -F.log_softmax(flat_annotation_logits, dim=1) annotation_nlls = torch.gather( input=annotation_nlls, dim=-1, index=masked_annotation_labels.unsqueeze(-1) ).squeeze(-1) output["annotation_nlls"] = annotation_nlls # (seq,) output["annotation_predictions"] = torch.argmax(flat_annotation_logits, dim=-1) # (seq,) output["masked_annotation_labels"] = masked_annotation_labels for n, u_annot in enumerate(["disease", "tissue", "cell_type", "sex"]): u_annotation_labels = annotation_labels[:, n] u_annotation_label_positions = u_annotation_labels != -100 masked_u_annotation_labels = u_annotation_labels[u_annotation_label_positions].long() u_annotation_logits = annotation_logits[:, n] # (batch, dim) u_annotation_logits = u_annotation_logits[u_annotation_label_positions] u_annotation_nlls = -F.log_softmax(u_annotation_logits, dim=1) u_annotation_nlls = torch.gather( input=u_annotation_nlls, dim=-1, index=masked_u_annotation_labels.unsqueeze(-1) ).squeeze(-1) output[f"{u_annot}_nlls"] = u_annotation_nlls # (seq,) output[f"{u_annot}_predictions"] = torch.argmax(u_annotation_logits, dim=-1) # (seq,) output[f"masked_{u_annot}_labels"] = masked_u_annotation_labels cross_entropies = F.cross_entropy( logits.view(-1, self.config.ntoken), all_labels.view(-1), reduction="none" ) # (seq len,) cross_entropies = cross_entropies.view(logits.shape[:-1]) annotation_ce = cross_entropies[:, : annotations.shape[1]] modeling_ce = cross_entropies[:, annotations.shape[1] :] output["annotation_loss"] = annotation_ce[annotation_labels != -100].mean() output["modeling_loss"] = modeling_ce[labels != -100].mean() if self.config.annotation_loss_weight is not None and self.config.modeling_loss_weight is not None: output["loss"] = ( self.config.annotation_loss_weight * output["annotation_loss"] + self.config.modeling_loss_weight * output["modeling_loss"] ) else: output["loss"] = cross_entropies[all_labels != -100].mean() if self.config.n_cls > 1: output["cls_output"] = self.cls_decoder(cell_emb) # (batch, n_cls) if self.cls_loss and labels is not None: output["loss"] = F.cross_entropy(output["cls_output"]["output"], labels.long()) if self.decoding_loss: logits = logits = self.decoder_head(output["cell_emb"]) + self.decoder_bias output["cls_output"] = { "output": F.log_softmax(logits, dim=-1) } # NOTE: only implemented for disease classification output["loss"] = F.cross_entropy(logits, annotation_labels[:, 0].long()) return output class TeddyGModelAnalysis(TeddyGModel): def __init__(self, config): super().__init__(config) if config.n_cls > 1: self.cls_decoder = ClassificationHeadAnalysis(config, config.n_cls, config.n_layers_cls)