import torch
import torch.nn as nn
from typing import Optional, Union, List, Dict, Any
from transformers import T5EncoderModel, AutoTokenizer, T5ForConditionalGeneration
from .base_tokenizer import BaseSequenceTokenizer
from .t5 import T5ForSequenceClassification, T5ForTokenClassification
# Mapping from human-readable preset names to HuggingFace Hub model repos.
# Keys are the preset strings accepted by the factory functions below.
presets = {
    'ANKH-Base': 'Synthyra/ANKH_base',
    'ANKH-Large': 'Synthyra/ANKH_large',
    'ANKH2-Large': 'Synthyra/ANKH2_large',
}
class ANKHTokenizerWrapper(BaseSequenceTokenizer):
    """Wrapper around an ANKH tokenizer that normalizes inputs and defaults.

    A single sequence string is promoted to a one-element batch, and the
    underlying tokenizer is invoked with PyTorch tensors, longest-padding,
    and special tokens enabled unless the caller overrides those kwargs.

    Note: the original defined a no-op ``__init__`` that only forwarded to
    ``super().__init__``; it has been removed — the base-class constructor
    (which stores ``self.tokenizer``) is inherited unchanged.
    """

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        """Tokenize one sequence or a batch of sequences.

        Args:
            sequences: A single sequence string or a list of sequence strings.
            **kwargs: Forwarded to the wrapped tokenizer. `return_tensors`,
                `padding`, and `add_special_tokens` are defaulted below but
                any caller-supplied value wins.

        Returns:
            The tokenizer's batch encoding (e.g. `input_ids`, `attention_mask`)
            as PyTorch tensors.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        kwargs.setdefault('return_tensors', 'pt')
        kwargs.setdefault('padding', 'longest')
        kwargs.setdefault('add_special_tokens', True)
        return self.tokenizer(sequences, **kwargs)
class AnkhForEmbedding(nn.Module):
    """Thin wrapper exposing only the encoder hidden states of an ANKH/T5 model."""

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        # Only the encoder stack is needed to produce embeddings.
        self.plm = T5EncoderModel.from_pretrained(model_path, dtype=dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        # Accepted for API parity with other wrappers but not forwarded to the
        # encoder: hidden states per layer are never returned by this method.
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> Union[torch.Tensor, tuple]:
        """Return the encoder's last hidden state for `input_ids`.

        Args:
            input_ids: Token id tensor for the batch.
            attention_mask: Optional padding mask.
            output_attentions: When truthy, also return attention maps.
            output_hidden_states: Ignored (kept for signature compatibility).
            **kwargs: Ignored extras, accepted for interface compatibility.

        Returns:
            `last_hidden_state`, or the tuple `(last_hidden_state, attentions)`
            when `output_attentions` is truthy.  (The original annotated the
            return as `torch.Tensor` only, which was wrong for the tuple case.)
        """
        if output_attentions:
            out = self.plm(input_ids, attention_mask=attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        return self.plm(input_ids, attention_mask=attention_mask).last_hidden_state
class AnkhForProteinGym(nn.Module):
    """ANKH/T5 encoder-decoder wrapper for scoring single-residue substitutions
    via T5-style span corruption (ProteinGym-style evaluation)."""

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        # Full seq2seq model: the decoder head is required to recover the
        # masked-residue distribution.
        self.plm = T5ForConditionalGeneration.from_pretrained(model_path, dtype=dtype)

    @torch.no_grad()
    def position_log_probs(
        self,
        seq: str,
        pos: int,
        tokenizer: Any,
        device: Optional[torch.device] = None,
        sentinel: str = "<extra_id_0>",
    ) -> torch.Tensor:
        """
        Compute log-probs over the vocab for the single position `pos` in `seq`
        using T5-style span corruption:
          - Encoder input: replace seq[pos] with <extra_id_0>
          - Decoder input: start with <extra_id_0>
          - The logits at the last decoder position correspond to the first token
            of the span, i.e., the masked residue distribution.

        Args:
            seq: Protein sequence string.
            pos: 0-based index of the residue to mask.
            tokenizer: Tokenizer whose call returns batched `input_ids` tensors.
            device: Target device; defaults to the model's own device.
            sentinel: Sentinel token that marks the corrupted span.

        Returns:
            Tensor of shape [vocab_size] with log-probabilities.

        Raises:
            ValueError: if `pos` is out of range or `sentinel` is unknown.
        """
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not 0 <= pos < len(seq):
            raise ValueError(f"pos {pos} out of range for len={len(seq)}")
        # Resolve device from the model's parameters when not given explicitly.
        if device is None:
            device = next(self.parameters()).device
        # Build encoder ids = tokenized left + sentinel + tokenized right (no spaces).
        left, right = seq[:pos], seq[pos + 1:]
        left_ids = tokenizer(left, add_special_tokens=False)["input_ids"][0].tolist() if left else []
        right_ids = tokenizer(right, add_special_tokens=False)["input_ids"][0].tolist() if right else []
        sent_id = tokenizer.convert_tokens_to_ids(sentinel)
        # NOTE(review): many tokenizers return the unk id (not None) for an
        # unknown token, so this only catches the explicit-None case.
        if sent_id is None:
            raise ValueError(f"Sentinel token {sentinel} not found in tokenizer.")
        enc_ids = torch.tensor([left_ids + [sent_id] + right_ids], dtype=torch.long, device=device)
        enc_mask = torch.ones_like(enc_ids, device=device)
        # Decoder primed with the SAME sentinel; the next token distribution is what we want.
        dec_ids = torch.tensor([[sent_id]], dtype=torch.long, device=device)
        # Bug fix: the original called `self(...)`, which dispatches to
        # nn.Module.forward — never defined on this class, so the call raises.
        # The wrapped seq2seq model must be invoked directly.
        out = self.plm(
            input_ids=enc_ids,
            attention_mask=enc_mask,
            decoder_input_ids=dec_ids,
            use_cache=False,
            output_hidden_states=False,
            output_attentions=False,
        )
        logits = out.logits  # [1, 1, vocab]
        return torch.log_softmax(logits[0, -1, :], dim=-1)
def get_ankh_tokenizer(preset: str, model_path: str = None):
    """Load the tokenizer for an ANKH preset, wrapped for batched use.

    Args:
        preset: Key into `presets`; used when `model_path` is not given.
        model_path: Explicit HuggingFace repo/path that overrides the preset.

    Returns:
        An ANKHTokenizerWrapper around the loaded tokenizer.
    """
    # Bug fix: the original ignored BOTH arguments and always loaded
    # 'Synthyra/ANKH_base'. The ANKH variants may well share one tokenizer,
    # but honoring the arguments keeps this consistent with the other factory
    # helpers; unknown presets still fall back to the base repo as before.
    model_path = model_path or presets.get(preset, 'Synthyra/ANKH_base')
    return ANKHTokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
def build_ankh_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Build an ANKH model for inference, plus its tokenizer.

    Args:
        preset: Preset name resolved through `presets` when `model_path` is None.
        masked_lm: If True, load the full seq2seq model (for masked-LM scoring);
            otherwise load an encoder-only embedding wrapper.
        dtype: Optional torch dtype for the loaded weights.
        model_path: Explicit repo/path overriding the preset.
        **kwargs: Ignored extras, accepted for interface compatibility.

    Returns:
        Tuple of (model in eval mode, tokenizer wrapper).
    """
    path = model_path if model_path is not None else presets[preset]
    if masked_lm:
        loaded = T5ForConditionalGeneration.from_pretrained(path, dtype=dtype)
    else:
        loaded = AnkhForEmbedding(path, dtype=dtype)
    return loaded.eval(), get_ankh_tokenizer(preset)
def get_ankh_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Build an ANKH model with a task head for fine-tuning, plus its tokenizer.

    Args:
        preset: Preset name resolved through `presets` when `model_path` is None.
        tokenwise: If True (and not hybrid), attach a per-token classification
            head; otherwise a sequence-level classification head.
        num_labels: Number of output labels for the classification heads.
        hybrid: If True, return the bare encoder (head supplied elsewhere).
        dtype: Optional torch dtype for the loaded weights.
        model_path: Explicit repo/path overriding the preset.

    Returns:
        Tuple of (model in eval mode, tokenizer wrapper).
    """
    path = model_path or presets[preset]
    if hybrid:
        loaded = T5EncoderModel.from_pretrained(path, dtype=dtype)
    elif tokenwise:
        loaded = T5ForTokenClassification.from_pretrained(path, num_labels=num_labels, dtype=dtype)
    else:
        loaded = T5ForSequenceClassification.from_pretrained(path, num_labels=num_labels, dtype=dtype)
    return loaded.eval(), get_ankh_tokenizer(preset)
if __name__ == '__main__':
    # Smoke test; run with: py -m src.protify.base_models.ankh
    demo_model, demo_tokenizer = build_ankh_model('ANKH-Base')
    print(demo_model)
    print(demo_tokenizer)
    sample = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
    print(demo_tokenizer(sample))