File size: 3,564 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
We use the FastPLM implementation of DPLM2.
"""
import sys
import os
import torch
import torch.nn as nn
from typing import List, Optional, Union, Dict

# Make the sibling 'FastPLMs' checkout importable: it lives two directory
# levels up from this file and is not an installed package, so it is added
# to sys.path once (idempotently) before the dplm2_fastplms import below.
_FASTPLMS = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'FastPLMs')
if _FASTPLMS not in sys.path:
    sys.path.insert(0, _FASTPLMS)

from dplm2_fastplms.modeling_dplm2 import (
    DPLM2ForMaskedLM,
    DPLM2ForSequenceClassification,
    DPLM2ForTokenClassification,
)
from transformers import EsmTokenizer
from .base_tokenizer import BaseSequenceTokenizer


# Human-readable preset name -> Hugging Face Hub repo id of the checkpoint.
presets = {
    "DPLM2-150": "airkingbd/dplm2_150m",
    "DPLM2-650": "airkingbd/dplm2_650m",
    "DPLM2-3B": "airkingbd/dplm2_3b",
}


class DPLM2TokenizerWrapper(BaseSequenceTokenizer):
    """Adapter exposing an EsmTokenizer via the BaseSequenceTokenizer interface.

    A lone string is promoted to a one-element batch, and DPLM2-friendly
    defaults (PyTorch tensors, longest-padding, special tokens) are filled in
    unless the caller overrides them.
    """

    def __init__(self, tokenizer: EsmTokenizer):
        super().__init__(tokenizer)

    def __call__(
        self, sequences: Union[str, List[str]], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Normalize a single sequence into a batch of one.
        batch = [sequences] if isinstance(sequences, str) else sequences
        # Caller-supplied kwargs always win over these defaults.
        for key, value in (
            ("return_tensors", "pt"),
            ("padding", "longest"),
            ("add_special_tokens", True),
        ):
            kwargs.setdefault(key, value)
        return self.tokenizer(batch, **kwargs)


class DPLM2ForEmbedding(nn.Module):
    """Wraps DPLM2ForMaskedLM so callers get hidden states directly.

    forward() returns the backbone's last_hidden_state, or a
    (last_hidden_state, attentions) pair when output_attentions is truthy.
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.dplm2 = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        # Extra **kwargs are accepted for call-site compatibility but are
        # deliberately not forwarded to the backbone (same as the original).
        outputs = self.dplm2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        hidden = outputs.last_hidden_state
        if not output_attentions:
            return hidden
        return hidden, outputs.attentions


def get_dplm2_tokenizer(preset: str, model_path: str = None):
    """Build the tokenizer wrapper shared by the DPLM2 model builders.

    Args:
        preset: Preset name; kept for interface symmetry with the model
            builders but not used to select the tokenizer.
        model_path: Optional explicit tokenizer source. Previously this
            parameter was accepted but silently ignored; it is now honored
            when provided, with unchanged behavior for the default None.

    Returns:
        DPLM2TokenizerWrapper around the loaded EsmTokenizer.
    """
    # NOTE(review): defaults to the small ESM-2 repo, which presumably shares
    # its amino-acid vocabulary with all DPLM2 checkpoints — confirm upstream.
    source = model_path or "facebook/esm2_t6_8M_UR50D"
    return DPLM2TokenizerWrapper(EsmTokenizer.from_pretrained(source))


def build_dplm2_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Instantiate a DPLM2 model and its tokenizer for inference.

    Args:
        preset: Key into `presets`; only consulted when `model_path` is None.
        masked_lm: If True, return the raw DPLM2ForMaskedLM; otherwise wrap
            it in DPLM2ForEmbedding to expose hidden states.
        dtype: Optional torch dtype forwarded to `from_pretrained`.
        model_path: Explicit checkpoint path/repo id overriding the preset.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        (model, tokenizer) with the model in eval mode.

    Raises:
        ValueError: If `model_path` is None and `preset` is not a known preset
            (previously this surfaced as a bare KeyError).
    """
    if model_path is None:
        if preset not in presets:
            raise ValueError(
                f"Unknown DPLM2 preset {preset!r}; expected one of {sorted(presets)}"
            )
        model_path = presets[preset]
    if masked_lm:
        model = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype).eval()
    else:
        model = DPLM2ForEmbedding(model_path, dtype=dtype).eval()
    tokenizer = get_dplm2_tokenizer(preset)
    return model, tokenizer


def get_dplm2_for_training(
    preset: str,
    tokenwise: bool = False,
    num_labels: int = None,
    hybrid: bool = False,
    dtype: torch.dtype = None,
    model_path: str = None,
):
    """Instantiate a DPLM2 model with a training head, plus its tokenizer.

    Args:
        preset: Key into `presets`; only consulted when `model_path` is None.
        tokenwise: Use a token-classification head instead of a
            sequence-classification head (ignored when `hybrid` is True).
        num_labels: Number of classification labels; required unless `hybrid`.
        hybrid: If True, return the masked-LM model rather than a classifier.
        dtype: Optional torch dtype forwarded to `from_pretrained`.
        model_path: Explicit checkpoint path/repo id overriding the preset.

    Returns:
        (model, tokenizer). NOTE(review): the model is returned in eval mode
        despite the function name — callers presumably call .train() later.

    Raises:
        ValueError: If the preset is unknown (and no `model_path` is given),
            or if `num_labels` is missing for a classification head.
    """
    if model_path is None:
        if preset not in presets:
            raise ValueError(
                f"Unknown DPLM2 preset {preset!r}; expected one of {sorted(presets)}"
            )
        model_path = presets[preset]
    if hybrid:
        model = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype).eval()
    else:
        # Fail fast instead of forwarding num_labels=None into the head config.
        if num_labels is None:
            raise ValueError("num_labels is required when hybrid=False")
        head_cls = DPLM2ForTokenClassification if tokenwise else DPLM2ForSequenceClassification
        model = head_cls.from_pretrained(model_path, num_labels=num_labels, dtype=dtype).eval()
    tokenizer = get_dplm2_tokenizer(preset)
    return model, tokenizer


if __name__ == "__main__":
    # py -m src.protify.base_models.dplm2
    model, tokenizer = build_dplm2_model("DPLM2-150")
    print(model)
    print(tokenizer)
    print(tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))