# ChatNT / multi_omics_model.py
import torch
from transformers import PreTrainedModel
from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
ChatNTConfig,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
TorchBioBrainDecoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
TorchBioBrainEncoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import ( # noqa
TorchMultiModalPerceiverResamplerProjection,
)


class TorchMultiOmicsModel(PreTrainedModel):
    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
super().__init__(config=config)
self.gpt_config = config.gpt_config
self.esm_config = config.esm_config
self.perceiver_resampler_config = config.perceiver_resampler_config
self.seq_token_id = config.seq_token_id
self.bio_pad_token_id = config.bio_pad_token_id
self.english_pad_token_id = config.english_pad_token_id
        # Correct seq_token_id: shift it into the GPT vocab range
        # (see the token-id remapping in forward).
        self.seq_token_id -= 1
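
        # Sub-modules: an ESM-based bio sequence encoder, a perceiver-resampler
        # projection that maps bio embeddings into the GPT embedding space, and
        # a GPT decoder that produces the logits.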
self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
self.biobrain_decoder = TorchBioBrainDecoder(
gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
)
self.projection_model = TorchMultiModalPerceiverResamplerProjection(
perceiver_resampler_config=self.perceiver_resampler_config,
input_embed_dim=self.esm_config.embed_dim,
embed_dim=self.gpt_config.embed_dim,
english_vocab_size=self.gpt_config.vocab_size,
bio_pad_token_id=self.bio_pad_token_id,
english_pad_token_id=self.english_pad_token_id,
)

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
"""
Args:
            multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: the prompt tokens (English tokens).
                    Shape (batch_size, num_english_tokens)
                bio_tokens_ids: the bio sequence tokens.
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)
            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)
            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.
Returns:
dict[str, torch.Tensor] containing:
- logits:
Shape (batch_size, num_tokens, vocab_size)
- projected_bio_embeddings:
Shape (batch_size, num_bio_sequences, ?, embed_dim)
"""
english_token_ids, bio_token_ids = multi_omics_tokens_ids
        # Remap out-of-range ids in the English tokens.
        # The default vocab size (32000) does not account for the extra
        # seq_token_id (=32000) that was added, so seq_token_id is remapped to
        # 31999 and token 31999 is remapped to 0 (the unknown token).
        # This is a workaround that avoids changing the vocab size in the config.
vocab_size = self.gpt_config.vocab_size
# Replace vocab
english_token_ids[english_token_ids == vocab_size - 1] = 0
projection_english_tokens_ids[
projection_english_tokens_ids == vocab_size - 1
] = 0
english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
vocab_size - 1
)
if bio_token_ids is None:
projected_bio_embeddings = None
else:
num_bio_sequences = bio_token_ids.shape[1]
if projected_bio_embeddings is None:
# Compute bio sequences embeddings
bio_embeddings_list = [
self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
for bio_seq_num in range(num_bio_sequences)
]
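                # Each list element is assumed to have shape
                # (batch_size, num_bio_tokens, esm_config.embed_dim).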
# Project these embeddings
projected_bio_embeddings = [
self.projection_model(
bio_token_ids=bio_token_ids[:, bio_seq_num],
bio_embeddings=bio_embeddings,
english_token_ids=projection_english_tokens_ids,
)
for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
]
projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
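                # Stacked shape: (batch_size, num_bio_sequences, ?, embed_dim),
                # where ? is presumably the resampler's output length.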
        # Decode the English tokens together with the projected bio embeddings
logits = self.biobrain_decoder(
english_token_ids=english_token_ids,
projected_bio_embeddings=projected_bio_embeddings,
)
outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
return outs
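

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint path, token ids
    # and sequence lengths below are assumptions, not values from this repo.
    model = TorchMultiOmicsModel.from_pretrained("path/to/chatnt_checkpoint")
    batch_size, num_english_tokens = 1, 32
    num_bio_sequences, num_bio_tokens = 2, 512
    english_tokens = torch.randint(0, 100, (batch_size, num_english_tokens))
    bio_tokens = torch.randint(
        0, 4, (batch_size, num_bio_sequences, num_bio_tokens)
    )
    outs = model(
        multi_omics_tokens_ids=(english_tokens, bio_tokens),
        projection_english_tokens_ids=english_tokens,
    )
    print(outs["logits"].shape)  # (batch_size, num_tokens, vocab_size)
    print(outs["projected_bio_embeddings"].shape)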