import torch
from transformers import PreTrainedModel

from genomics_research.biobrain_p1.porting_to_pytorch.configs.chatNT_config import (
    ChatNTConfig,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_decoder import (
    TorchBioBrainDecoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.biobrain_encoder import (
    TorchBioBrainEncoder,
)
from genomics_research.biobrain_p1.porting_to_pytorch.models.perceiver_resampler_projection import (  # noqa
    TorchMultiModalPerceiverResamplerProjection,
)


class TorchMultiOmicsModel(PreTrainedModel):
    config_class = ChatNTConfig

    def __init__(self, config: ChatNTConfig) -> None:
        super().__init__(config=config)
        self.gpt_config = config.gpt_config
        self.esm_config = config.esm_config
        self.perceiver_resampler_config = config.perceiver_resampler_config
        self.seq_token_id = config.seq_token_id
        self.bio_pad_token_id = config.bio_pad_token_id
        self.english_pad_token_id = config.english_pad_token_id

        # Shift seq_token_id down by one so it matches the token-id remapping
        # applied in forward() (seq_token_id == vocab_size is mapped to
        # vocab_size - 1 there).
        self.seq_token_id -= 1

        self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
        self.biobrain_decoder = TorchBioBrainDecoder(
            gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
        )
        self.projection_model = TorchMultiModalPerceiverResamplerProjection(
            perceiver_resampler_config=self.perceiver_resampler_config,
            input_embed_dim=self.esm_config.embed_dim,
            embed_dim=self.gpt_config.embed_dim,
            english_vocab_size=self.gpt_config.vocab_size,
            bio_pad_token_id=self.bio_pad_token_id,
            english_pad_token_id=self.english_pad_token_id,
        )

    def forward(
        self,
        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
        projection_english_tokens_ids: torch.Tensor,
        projected_bio_embeddings: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """

        Args:
            multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                english_tokens_ids: Represents the prompt tokens (english tokens)
                    Shape (batch_size, num_english_tokens)

                bio_tokens_ids: Represents the bio sequences tokens
                    Shape (batch_size, num_bio_sequences, num_bio_tokens)

            projection_english_tokens_ids (torch.Tensor):
                Shape (batch_size, num_english_tokens)

            projected_bio_embeddings (torch.Tensor, optional):
                Shape (batch_size, num_bio_sequences, ?, embed_dim).
                Defaults to None.

        Returns:
            dict[str, torch.Tensor] containing:
                - logits:
                    Shape (batch_size, num_tokens, vocab_size)

                - projected_bio_embeddings:
                    Shape (batch_size, num_bio_sequences, ?, embed_dim)
        """
        english_token_ids, bio_token_ids = multi_omics_tokens_ids

        # Remap out-of-range English token ids.
        # The seq token (seq_token_id = 32000) was added on top of the default
        # vocab size (32000), so its id falls outside the embedding table.
        # As a workaround (instead of changing the vocab size in the config),
        # map seq_token_id down to vocab_size - 1 (31999) and map the original
        # token 31999 to 0, the unknown token.
        vocab_size = self.gpt_config.vocab_size
        # Replace vocab
        english_token_ids[english_token_ids == vocab_size - 1] = 0
        projection_english_tokens_ids[
            projection_english_tokens_ids == vocab_size - 1
        ] = 0
        english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
        projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
            vocab_size - 1
        )

        if bio_token_ids is None:
            projected_bio_embeddings = None
        else:
            num_bio_sequences = bio_token_ids.shape[1]

            if projected_bio_embeddings is None:
                # Compute bio sequences embeddings
                bio_embeddings_list = [
                    self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                    for bio_seq_num in range(num_bio_sequences)
                ]

                # Project these embeddings
                projected_bio_embeddings = [
                    self.projection_model(
                        bio_token_ids=bio_token_ids[:, bio_seq_num],
                        bio_embeddings=bio_embeddings,
                        english_token_ids=projection_english_tokens_ids,
                    )
                    for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
                ]
                projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)

        # Decode: run the GPT decoder over the English tokens together with the
        # projected bio embeddings to produce logits.
        logits = self.biobrain_decoder(
            english_token_ids=english_token_ids,
            projected_bio_embeddings=projected_bio_embeddings,
        )

        outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}

        return outs
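

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal sketch of how the model might be instantiated and called, assuming
# ChatNTConfig() can be constructed with valid default sub-configs and token ids.
# The token ids and shapes below are placeholders, not values taken from the
# real ChatNT tokenizers or checkpoints.
if __name__ == "__main__":
    config = ChatNTConfig()  # assumption: defaults yield valid sub-configs
    model = TorchMultiOmicsModel(config)
    model.eval()

    batch_size, num_english_tokens = 1, 16
    num_bio_sequences, num_bio_tokens = 1, 32

    # Placeholder token ids; a real pipeline would obtain these from the English
    # and bio tokenizers shipped with the checkpoint.
    english_tokens = torch.zeros(batch_size, num_english_tokens, dtype=torch.long)
    bio_tokens = torch.zeros(
        batch_size, num_bio_sequences, num_bio_tokens, dtype=torch.long
    )

    with torch.no_grad():
        outs = model(
            multi_omics_tokens_ids=(english_tokens, bio_tokens),
            projection_english_tokens_ids=english_tokens,
        )
    # logits: (batch_size, num_tokens, vocab_size)
    # projected_bio_embeddings: (batch_size, num_bio_sequences, ?, embed_dim)
    print(outs["logits"].shape)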