File size: 3,515 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
We use the FastPLM implementation of DPLM.
"""
import sys
import os
import torch
import torch.nn as nn
from typing import List, Optional, Union, Dict

_FASTPLMS = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'FastPLMs')
if _FASTPLMS not in sys.path:
    sys.path.insert(0, _FASTPLMS)

from dplm_fastplms.modeling_dplm import (
    DPLMForMaskedLM,
    DPLMForSequenceClassification,
    DPLMForTokenClassification,
)
from transformers import EsmTokenizer
from .base_tokenizer import BaseSequenceTokenizer


# Mapping from human-readable preset names to HuggingFace Hub checkpoint ids.
# Keys are the strings accepted by build_dplm_model / get_dplm_for_training.
presets = {
    'DPLM-150': 'airkingbd/dplm_150m',
    'DPLM-650': 'airkingbd/dplm_650m',
    'DPLM-3B': 'airkingbd/dplm_3b',
}


class DPLMTokenizerWrapper(BaseSequenceTokenizer):
    """Tokenizer wrapper that applies DPLM-friendly defaults on every call.

    Unless the caller overrides them, batches are padded to the longest
    sequence, special tokens are added, and PyTorch tensors are returned.
    """

    def __init__(self, tokenizer: EsmTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        # A lone string is promoted to a batch of one.
        batch = [sequences] if isinstance(sequences, str) else sequences
        defaults = {
            'return_tensors': 'pt',
            'padding': 'longest',
            'add_special_tokens': True,
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        return self.tokenizer(batch, **kwargs)


class DPLMForEmbedding(nn.Module):
    """Wraps a DPLM masked-LM checkpoint to expose embeddings.

    forward() returns, in order of precedence:
      * ``(last_hidden_state, attentions)`` when ``output_attentions`` is truthy
        (``return_logits`` is ignored on this path — preserved behavior);
      * ``(last_hidden_state, logits)`` when the module was built with
        ``return_logits=True``;
      * ``last_hidden_state`` alone otherwise.
    """

    def __init__(self, model_path: str, return_logits: bool = False, dtype: torch.dtype = None):
        super().__init__()
        # Loads the underlying checkpoint; dtype is forwarded untouched.
        self.dplm = DPLMForMaskedLM.from_pretrained(model_path, dtype=dtype)
        self.return_logits = return_logits

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = False,
            **kwargs,
    ) -> torch.Tensor:
        # Attention request takes priority over logits.
        if output_attentions:
            result = self.dplm(
                input_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
            )
            return result.last_hidden_state, result.attentions
        result = self.dplm(input_ids, attention_mask=attention_mask)
        if not self.return_logits:
            return result.last_hidden_state
        return result.last_hidden_state, result.logits


def get_dplm_tokenizer(preset: str, model_path: str = None):
    """Return a wrapped tokenizer for DPLM models.

    NOTE(review): ``preset`` and ``model_path`` are accepted only for
    signature uniformity with the model factories and are currently
    ignored — the ESM-2 tokenizer is always loaded (presumably every DPLM
    checkpoint shares its vocabulary; confirm against the checkpoints).
    """
    base = EsmTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    return DPLMTokenizerWrapper(base)


def build_dplm_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Build an eval-mode DPLM embedding model and its tokenizer.

    ``model_path`` (when truthy) overrides the checkpoint implied by
    ``preset``; when ``masked_lm`` is True the model also returns logits.
    Extra keyword arguments are accepted for interface compatibility and
    ignored.
    """
    # Falsy model_path (None or '') falls back to the preset checkpoint.
    checkpoint = model_path or presets[preset]
    embedder = DPLMForEmbedding(checkpoint, return_logits=masked_lm, dtype=dtype)
    return embedder.eval(), get_dplm_tokenizer(preset)


def get_dplm_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Load a DPLM head for fine-tuning, plus its tokenizer.

    Head selection: ``hybrid`` picks the masked-LM head (``num_labels``
    unused on that path); otherwise ``tokenwise`` chooses per-token
    classification and the default is sequence classification. The model is
    returned in eval mode.
    """
    # Falsy model_path (None or '') falls back to the preset checkpoint.
    checkpoint = model_path or presets[preset]
    if hybrid:
        model = DPLMForMaskedLM.from_pretrained(checkpoint, dtype=dtype)
    elif tokenwise:
        model = DPLMForTokenClassification.from_pretrained(checkpoint, num_labels=num_labels, dtype=dtype)
    else:
        model = DPLMForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels, dtype=dtype)
    return model.eval(), get_dplm_tokenizer(preset)


if __name__ == '__main__':
    # py -m src.protify.base_models.dplm
    # Smoke test: downloads the DPLM-150 checkpoint, prints the model and
    # tokenizer, then tokenizes one hard-coded sequence. Note the sequence
    # contains 'B' and 'O', which are not standard amino-acid codes —
    # presumably exercising the tokenizer's rare/unknown-token handling;
    # confirm against the ESM vocabulary.
    model, tokenizer = build_dplm_model('DPLM-150')
    print(model)
    print(tokenizer)
    print(tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'))