import torch
from transformers import PreTrainedModel

from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
    ChatNTConfig,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
    TorchBioBrainDecoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
    TorchBioBrainEncoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import (  # noqa
    TorchMultiModalPerceiverResamplerProjection,
)


class TorchMultiOmicsModel(PreTrainedModel):
    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.esm_config = config.esm_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

        # Shift seq_token_id down by one so it fits inside the GPT vocabulary
        # (matches the token id remapping performed in forward()).
        self.seq_token_id -= 1

        self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.esm_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            multi_omics_tokens_ids (tuple[torch.Tensor, torch.Tensor | None]):
                english_tokens_ids: Represents the prompt tokens (english tokens).
                    Shape (batch_size, num_english_tokens)
                bio_tokens_ids: Represents the bio sequences tokens, or None if
                    the prompt contains no bio sequence.
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)
            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)
            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits: Shape (batch_size, num_tokens, vocab_size)
                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids

        # Remap token ids that fall outside the GPT embedding table.
        # The default vocab size (32000) does not account for seq_token_id
        # (= 32000), which was added on top of it. As a workaround, instead of
        # changing vocab_size in the config, map seq_token_id down to
        # vocab_size - 1 (= 31999), and map the token previously at
        # vocab_size - 1 (an unknown token) to 0.
        vocab_size = self.gpt_config.vocab_size
        english_token_ids[english_token_ids == vocab_size - 1] = 0
        projection_english_tokens_ids[
            projection_english_tokens_ids == vocab_size - 1
        ] = 0
        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
            vocab_size - 1
        )

        if bio_token_ids is None:
            projected_bio_embeddings = None
        else:
            num_bio_sequences = bio_token_ids.shape[1]

            if projected_bio_embeddings is None:
                # Compute the embeddings of each bio sequence with the encoder.
                bio_embeddings_list = [
                    self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                    for bio_seq_num in range(num_bio_sequences)
                ]

                # Project these embeddings into the decoder embedding space.
                projected_bio_embeddings = [
                    self.projection_model(
                        bio_token_ids=bio_token_ids[:, bio_seq_num],
                        bio_embeddings=bio_embeddings,
                        english_token_ids=projection_english_tokens_ids,
                    )
                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
                ]
                projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

        # Decode the prompt, conditioning on the projected bio embeddings.
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )

        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
        return outs
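

# Minimal usage sketch (hypothetical): a smoke test illustrating the expected
# input shapes for forward(). It assumes that ChatNTConfig's default
# constructor yields a self-consistent config; the batch/sequence sizes and
# the bio vocabulary range below are placeholders, not values tied to any
# released checkpoint.
if __name__ == "__main__":
    config = ChatNTConfig()  # assumption: defaults form a valid config
    model = TorchMultiOmicsModel(config)
    model.eval()

    batch_size, num_english_tokens = 2, 16
    num_bio_sequences, num_bio_tokens = 1, 32

    english_tokens_ids = torch.randint(
        0, config.gpt_config.vocab_size, (batch_size, num_english_tokens)
    )
    # Placeholder bio token ids; a real pipeline would use the bio tokenizer.
    bio_tokens_ids = torch.randint(
        0, 4, (batch_size, num_bio_sequences, num_bio_tokens)
    )

    with torch.no_grad():
        outs = model(
            multi_omics_tokens_ids=(english_tokens_ids, bio_tokens_ids),
            projection_english_tokens_ids=english_tokens_ids,
        )

    print(outs["logits"].shape)  # (batch_size, num_tokens, vocab_size)
    print(outs["projected_bio_embeddings"].shape)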