File size: 5,628 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
import torch.nn as nn
from typing import Optional, Union, List, Dict, Any
from transformers import T5EncoderModel, AutoTokenizer, T5ForConditionalGeneration

from .base_tokenizer import BaseSequenceTokenizer
from .t5 import T5ForSequenceClassification, T5ForTokenClassification


# Maps the preset names accepted by the builder functions in this module to
# the checkpoint identifiers that are passed to `from_pretrained`.
presets = {
    'ANKH-Base': 'Synthyra/ANKH_base',
    'ANKH-Large': 'Synthyra/ANKH_large',
    'ANKH2-Large': 'Synthyra/ANKH2_large',
}


class ANKHTokenizerWrapper(BaseSequenceTokenizer):
    """Callable wrapper that tokenizes one or many sequences with ANKH defaults."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Tokenize `sequences` with the wrapped tokenizer.

        A single string is treated as a batch of one. Unless overridden via
        `kwargs`, the call uses `return_tensors='pt'`, `padding='longest'`
        and `add_special_tokens=True`.

        Returns whatever the wrapped tokenizer returns for the batch.
        """
        batch = [sequences] if isinstance(sequences, str) else sequences
        # Caller-supplied options take precedence over the defaults.
        options = {
            'return_tensors': 'pt',
            'padding': 'longest',
            'add_special_tokens': True,
            **kwargs,
        }
        # NOTE(review): `self.tokenizer` is presumably set by
        # BaseSequenceTokenizer.__init__ — confirm against the base class.
        return self.tokenizer(batch, **options)


class AnkhForEmbedding(nn.Module):
    """Encoder-only wrapper returning the last hidden state of a T5 encoder."""

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.plm = T5EncoderModel.from_pretrained(model_path, dtype=dtype)

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = False,
            **kwargs,
    ) -> torch.Tensor:
        """Run the encoder and return its last hidden state.

        When `output_attentions` is truthy the result is a tuple
        `(last_hidden_state, attentions)`; otherwise only the last hidden
        state is returned. `output_hidden_states` and any extra `kwargs` are
        accepted but not forwarded to the encoder.
        """
        if not output_attentions:
            encoded = self.plm(input_ids, attention_mask=attention_mask)
            return encoded.last_hidden_state

        encoded = self.plm(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        return encoded.last_hidden_state, encoded.attentions


class AnkhForProteinGym(nn.Module):
    """Wrapper around a T5 conditional-generation model for per-position scoring."""

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        self.plm = T5ForConditionalGeneration.from_pretrained(model_path, dtype=dtype)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped model.

        Needed because `position_log_probs` invokes `self(...)`; without a
        `forward` the base `nn.Module` raises `NotImplementedError`.
        """
        return self.plm(*args, **kwargs)

    @torch.no_grad()
    def position_log_probs(
        self,
        seq: str,
        pos: int,
        tokenizer: Any,
        device: Optional[torch.device] = None,
        sentinel: str = "<extra_id_0>",
    ) -> torch.Tensor:
        """
        Compute log-probs over the vocab for the single position `pos` in `seq`
        using T5-style span corruption:

        - Encoder input: replace seq[pos] with <extra_id_0>
        - Decoder input: start with <extra_id_0>
        - The logits at the last decoder position correspond to the first token of the span,
          i.e., the masked residue distribution.

        Args:
            seq: The full sequence string.
            pos: Zero-based index of the position to mask; must satisfy
                `0 <= pos < len(seq)`.
            tokenizer: Callable returning a mapping with tensor `"input_ids"`
                (first row is used) and exposing `convert_tokens_to_ids`.
            device: Device for the input tensors; defaults to the device of
                this module's first parameter.
            sentinel: Token used to mark the masked position.

        Returns: tensor of shape [vocab_size] (log-probs).

        Raises:
            ValueError: If `sentinel` is not known to the tokenizer.
        """
        assert 0 <= pos < len(seq), f"pos {pos} out of range for len={len(seq)}"

        # Resolve device
        if device is None:
            device = next(self.parameters()).device

        # Build encoder ids = tokenized left + sentinel + tokenized right (no spaces).
        left, right = seq[:pos], seq[pos+1:]
        if left:
            left_ids = tokenizer(left, add_special_tokens=False)["input_ids"][0].tolist()
        else:
            left_ids = []
        if right:
            right_ids = tokenizer(right, add_special_tokens=False)["input_ids"][0].tolist()
        else:
            right_ids = []

        sent_id = tokenizer.convert_tokens_to_ids(sentinel)
        # Tokenizers commonly map unknown tokens to the unk id rather than
        # None, so treat that as "not found" as well when an unk id is exposed.
        unk_id = getattr(tokenizer, "unk_token_id", None)
        if sent_id is None or (unk_id is not None and sent_id == unk_id):
            raise ValueError(f"Sentinel token {sentinel} not found in tokenizer.")

        enc_ids = torch.tensor([left_ids + [sent_id] + right_ids], dtype=torch.long, device=device)
        enc_mask = torch.ones_like(enc_ids, device=device)

        # Decoder primed with the SAME sentinel; the next token distribution is what we want.
        dec_ids = torch.tensor([[sent_id]], dtype=torch.long, device=device)

        out = self(
            input_ids=enc_ids,
            attention_mask=enc_mask,
            decoder_input_ids=dec_ids,
            use_cache=False,
            output_hidden_states=False,
            output_attentions=False,
        )
        logits = out.logits  # [1, 1, vocab]
        log_probs = torch.log_softmax(logits[0, -1, :], dim=-1)
        return log_probs


def get_ankh_tokenizer(preset: str, model_path: str = None):
    """Return an `ANKHTokenizerWrapper` around the 'Synthyra/ANKH_base' tokenizer.

    NOTE(review): `preset` and `model_path` are accepted but not used; the
    tokenizer is always loaded from the base checkpoint. Presumably all ANKH
    presets share one vocabulary — confirm before changing.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained('Synthyra/ANKH_base')
    return ANKHTokenizerWrapper(hf_tokenizer)


def build_ankh_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Load an ANKH model in eval mode together with its tokenizer.

    Args:
        preset: Key into `presets`; used when `model_path` is falsy.
        masked_lm: If true, load `T5ForConditionalGeneration`; otherwise wrap
            the encoder in `AnkhForEmbedding`.
        dtype: Forwarded to the model loader.
        model_path: Optional checkpoint that overrides the preset lookup.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    checkpoint = model_path or presets[preset]
    if not masked_lm:
        network = AnkhForEmbedding(checkpoint, dtype=dtype)
    else:
        network = T5ForConditionalGeneration.from_pretrained(checkpoint, dtype=dtype)
    network = network.eval()
    return network, get_ankh_tokenizer(preset)


def get_ankh_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Load an ANKH model variant for downstream training, plus its tokenizer.

    Args:
        preset: Key into `presets`; used when `model_path` is falsy.
        tokenwise: Selects the token-classification head instead of the
            sequence-classification head (ignored when `hybrid` is true).
        num_labels: Forwarded to the classification model loaders.
        hybrid: If true, load the bare `T5EncoderModel` with no head.
        dtype: Forwarded to the model loader.
        model_path: Optional checkpoint that overrides the preset lookup.

    Returns:
        A `(model, tokenizer)` tuple; the model is returned in eval mode.
    """
    checkpoint = model_path or presets[preset]
    if hybrid:
        network = T5EncoderModel.from_pretrained(checkpoint, dtype=dtype)
    elif tokenwise:
        network = T5ForTokenClassification.from_pretrained(checkpoint, num_labels=num_labels, dtype=dtype)
    else:
        network = T5ForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels, dtype=dtype)
    network = network.eval()
    return network, get_ankh_tokenizer(preset)


if __name__ == '__main__':
    # py -m src.protify.base_models.ankh
    # Smoke check: load the base preset, show the model and tokenizer, then
    # tokenize one example sequence and print the result.
    model, tokenizer = build_ankh_model('ANKH-Base')
    print(model)
    print(tokenizer)
    example_sequence = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
    encoded = tokenizer(example_sequence)
    print(encoded)