import torch
from transformers import PreTrainedModel

from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
    ChatNTConfig,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
    TorchBioBrainDecoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
    TorchBioBrainEncoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import (
    TorchMultiModalPerceiverResamplerProjection,
)


class TorchMultiOmicsModel(PreTrainedModel):
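    """Multi-omics (ChatNT-style) model: a bio sequence encoder
    (TorchBioBrainEncoder), a Perceiver-resampler projection, and a GPT
    decoder (TorchBioBrainDecoder) that consumes English prompt tokens
    together with projected bio embeddings. (Summary inferred from the
    components wired up in __init__.)
    """
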
    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.esm_config = config.esm_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

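        # Shift the seq token id down by one, presumably to match the id
        # remapping applied to the English tokens in forward() below, where
        # ids equal to vocab_size are shifted down to vocab_size - 1.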
        self.seq_token_id -= 1

        # Bio sequence encoder (ESM-style).
        self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
        # English/GPT decoder; consumes English tokens plus projected bio embeddings.
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        # Perceiver resampler that maps bio embeddings (input_embed_dim) into
        # the decoder's embedding space (embed_dim).
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.esm_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            multi_omics_tokens_ids (tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: Represents the prompt tokens (English tokens).
                    Shape (batch_size, num_english_tokens)
                bio_tokens_ids: Represents the bio sequence tokens.
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)

            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits:
                    Shape (batch_size, num_tokens, vocab_size)
                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids

        # Remap out-of-range English token ids into the GPT vocabulary:
        # ids equal to vocab_size - 1 are reused as 0, and ids equal to
        # vocab_size (e.g. an extra special token appended after the base
        # vocabulary) are shifted down to vocab_size - 1. Note these
        # assignments mutate the input tensors in place.
        vocab_size = self.gpt_config.vocab_size
        english_token_ids[english_token_ids == vocab_size - 1] = 0
        projection_english_tokens_ids[
            projection_english_tokens_ids == vocab_size - 1
        ] = 0
        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
            vocab_size - 1
        )
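        # Worked example of the remapping above (illustrative, assuming
        # vocab_size == 8): ids [3, 7, 8, 5] become [3, 0, 7, 5] -- the id 7
        # (== vocab_size - 1) collapses to 0 and the id 8 (== vocab_size)
        # takes its place as 7.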

        if bio_token_ids is None:
            projected_bio_embeddings = None
        else:
            num_bio_sequences = bio_token_ids.shape[1]

            # Encode and project only if embeddings were not already provided
            # (e.g. cached from a previous generation step).
            if projected_bio_embeddings is None:
                # Encode each bio sequence independently.
                bio_embeddings_list = [
                    self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                    for bio_seq_num in range(num_bio_sequences)
                ]

                # Project each sequence's embeddings into the decoder's
                # embedding space, conditioned on the English prompt tokens.
                projected_bio_embeddings = [
                    self.projection_model(
                        bio_token_ids=bio_token_ids[:, bio_seq_num],
                        bio_embeddings=bio_embeddings,
                        english_token_ids=projection_english_tokens_ids,
                    )
                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
                ]
                projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
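                # Resulting shape: (batch_size, num_bio_sequences,
                # num_resampled_tokens, embed_dim), where num_resampled_tokens
                # is presumably fixed by the Perceiver resampler config (the
                # "?" in the docstring shapes).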

        # Decode the English tokens together with the projected bio embeddings.
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )

        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}

        return outs
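

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). It assumes ChatNTConfig can be built with
# defaults and that gpt_config exposes vocab_size; the shapes and token id
# ranges below are made up for the example, not taken from this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ChatNTConfig()  # assumption: defaults exist for all sub-configs
    model = TorchMultiOmicsModel(config).eval()

    batch_size, num_english_tokens = 2, 16
    num_bio_sequences, num_bio_tokens = 1, 32
    english_ids = torch.randint(
        0, config.gpt_config.vocab_size - 1, (batch_size, num_english_tokens)
    )
    bio_ids = torch.randint(0, 4, (batch_size, num_bio_sequences, num_bio_tokens))

    with torch.no_grad():
        outs = model(
            multi_omics_tokens_ids=(english_ids, bio_ids),
            projection_english_tokens_ids=english_ids,
        )
    # logits: (batch_size, num_tokens, vocab_size)
    print(outs["logits"].shape)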