from typing import Optional

import torch
from transformers import PreTrainedModel

from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
    ChatNTConfig,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
    TorchBioBrainDecoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
    TorchBioBrainEncoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import (
    TorchMultiModalPerceiverResamplerProjection,
)
|
|
|
|
class TorchMultiOmicsModel(PreTrainedModel):
    """Multi-modal model combining English (prompt) tokens with biological
    sequence tokens.

    Pipeline (see ``forward``): each bio sequence is embedded by a
    ``TorchBioBrainEncoder``, projected into the decoder's embedding space by a
    perceiver-resampler projection, and the projected embeddings are consumed
    together with the English tokens by a GPT-style ``TorchBioBrainDecoder``.
    """

    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
        """Build the encoder, projection and decoder sub-modules.

        Args:
            config: Aggregated configuration carrying the GPT decoder config,
                the ESM encoder config, the perceiver-resampler config and the
                special/pad token ids.
        """
        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.esm_config = config.esm_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

        # Shift the sequence-token id down by one before handing it to the
        # decoder. NOTE(review): presumably reconciles an off-by-one between
        # the source framework's tokenizer convention and this PyTorch port —
        # confirm against the upstream tokenizer.
        self.seq_token_id -= 1

        self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.esm_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """Run the encoder → projection → decoder pipeline.

        Args:
            multi_omics_tokens_ids (tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: Represents the prompt tokens (english tokens)
                    Shape (batch_size, num_english_tokens)

                bio_tokens_ids: Represents the bio sequences tokens
                    Shape (batch_size, num_bio_sequences, num_bio_tokens).
                    May be None, in which case no bio embeddings are computed.

            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

            projected_bio_embeddings (torch.Tensor, optional):
                Pre-computed projected embeddings from a previous call; when
                given (and bio tokens are present) the encoder/projection step
                is skipped. Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits:
                    Shape (batch_size, num_tokens, vocab_size)

                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim),
                    or None when no bio tokens were supplied.
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids

        # Work on copies: the remapping below is done in place and must not
        # clobber the caller's tensors (the previous version mutated its
        # inputs, which silently corrupted them for any reuse by the caller).
        english_token_ids = english_token_ids.clone()
        projection_english_tokens_ids = projection_english_tokens_ids.clone()

        # Fold out-of-range special token ids back into the decoder
        # vocabulary. Order matters: id (vocab_size - 1) is first mapped to 0,
        # then id vocab_size is moved down to (vocab_size - 1) — so an
        # original id of vocab_size ends at vocab_size - 1, not 0.
        # NOTE(review): presumably reconciles a vocabulary offset introduced
        # by the port — confirm against the source checkpoint's tokenizer.
        vocab_size = self.gpt_config.vocab_size
        english_token_ids[english_token_ids == vocab_size - 1] = 0
        projection_english_tokens_ids[
            projection_english_tokens_ids == vocab_size - 1
        ] = 0
        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
            vocab_size - 1
        )

        if bio_token_ids is None:
            # No bio modality in this call.
            projected_bio_embeddings = None
        elif projected_bio_embeddings is None:
            # Encode and project each bio sequence independently, then stack
            # along the sequence axis.
            num_bio_sequences = bio_token_ids.shape[1]
            bio_embeddings_list = [
                self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                for bio_seq_num in range(num_bio_sequences)
            ]
            projected_bio_embeddings = torch.stack(
                [
                    self.projection_model(
                        bio_token_ids=bio_token_ids[:, bio_seq_num],
                        bio_embeddings=bio_embeddings,
                        english_token_ids=projection_english_tokens_ids,
                    )
                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
                ],
                dim=1,
            )
        # else: reuse the projected_bio_embeddings supplied by the caller.

        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )

        return {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
|
|