# DELM/src/models/plm.py
import re

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer

def EsmModelInfo(name: str) -> dict:
    """Look up metadata for a protein language model (PLM) by its short name.

    Args:
        name: model short name, e.g. "esm2_t33_650M_UR50D".

    Returns:
        dict with keys "dim" (embedding size), "layers" (transformer layer
        count; -1 where not recorded), and "model" (checkpoint identifier).
    """
return {
"esm2_t48_15B_UR50D": {
"dim": 5120,
"layers": 48,
"model": "facebook/esm2_t48_15B_UR50D",
},
"esm2_t36_3B_UR50D": {
"dim": 2560,
"layers": 36,
"model": "facebook/esm2_t36_3B_UR50D",
},
"esm2_t33_650M_UR50D": {
"dim": 1280,
"layers": 33,
"model": "facebook/esm2_t33_650M_UR50D",
},
"esm2_t30_150M_UR50D": {
"dim": 640,
"layers": 30,
"model": "facebook/esm2_t30_150M_UR50D",
},
"esm2_t12_35M_UR50D": {
"dim": 480,
"layers": 12,
"model": "facebook/esm2_t12_35M_UR50D",
},
"esm2_t6_8M_UR50D": {
"dim": 320,
"layers": 6,
"model": "facebook/esm2_t6_8M_UR50D",
},
"esm1b_t33_650M_UR50S": {
"dim": 1280,
"layers": 33,
"model": "facebook/esm1b_t33_650M_UR50S",
},
"prot_t5_xl_half_uniref50-enc": {
"dim": 1024,
"layers": 24,
"model": "Rostlab/prot_t5_xl_uniref50",
},
"prot_t5_xl_bfd": {
"dim": 1024,
"layers": 24,
"model": "Rostlab/prot_t5_xl_bfd",
},
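        # The ESM C entries below are metadata only; get_model() at the end
        # of this file has no loader for them.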
"esmc-6b-2024-12": {
"dim": 2560,
"layers": -1,
"model": "esmc-6b-2024-12",
},
"esmc_300m": {
"dim": 768,
"layers": -1,
"model": "esmc_300m",
},
"esmc_600m": {
"dim": 1152,
"layers": -1,
"model": "esmc_600m",
},
}[name]
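
# Example lookup (values taken from the table above):
#   EsmModelInfo("esm2_t33_650M_UR50D")
#   -> {"dim": 1280, "layers": 33, "model": "facebook/esm2_t33_650M_UR50D"}

# Short display abbreviation for each supported PLM.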
plm2abbr = {
    "esm2_t48_15B_UR50D": "ESM2_T48",
    "esm2_t36_3B_UR50D": "ESM2_T36",
    "esm2_t33_650M_UR50D": "ESM2_T33",
    "esm2_t30_150M_UR50D": "ESM2_T30",
    "esm2_t12_35M_UR50D": "ESM2_T12",
    "esm2_t6_8M_UR50D": "ESM2_T6",
    "esm1b_t33_650M_UR50S": "ESM1B_T33",
    "prot_t5_xl_half_uniref50-enc": "PT_UR",
    "prot_t5_xl_bfd": "PT_BFD",
}

class EsmEncoder(nn.Module):
    """Wraps a HuggingFace ESM checkpoint and returns per-residue embeddings."""

    def __init__(self, model_name, dev):
        super().__init__()
        # `dev` is unused: placement is delegated to device_map below.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            device_map="balanced",  # alternatives: "auto", "balanced_low_0"
            torch_dtype=torch.float32,
            # torch_dtype=torch.float16,  # optional fp16 variant
            offload_folder=".cache/offload",
            offload_state_dict=True,
        )
        if model_name == "facebook/esm2_t48_15B_UR50D":
            # Shorter chunk for the 15B model to limit memory use.
            self.max_len = 512
        else:
            self.max_len = 960
        self.overlap = 31  # residues of shared context between adjacent chunks
        self.model.eval()
    def forward(self, _seqs):
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            seqs = _seqs[0]
            # Split into max_len chunks with `overlap` residues of extra
            # context on the left and right where available.
            seqs = [
                seqs[max(0, i - self.overlap): (i + self.max_len + self.overlap)]
                for i in range(0, len(seqs), self.max_len)
            ]
            segs = []
            for seq in seqs:
                inputs = self.tokenizer(
                    [seq],
                    return_tensors="pt",
                ).to(self.model.device)
                # Final-layer hidden states, one vector per token.
                outputs = (
                    self.model(**inputs).last_hidden_state.squeeze(0).detach().cpu()
                )
                # Raw input embeddings for the same tokens.
                outputs0 = self.model.embeddings(**inputs).squeeze(0).detach().cpu()
                # (tokens, dim, 2): channel 0 = input embeddings, 1 = last layer.
                segs.append(torch.stack([outputs0, outputs], dim=-1))
            # Stitch the chunks back together so that every residue appears
            # exactly once: drop the CLS token (index 0) and each chunk's
            # left-overlap region.
            t = []
            for i in range(len(seqs)):
                if i == 0:
                    t.append(segs[i][1: (1 + self.max_len)])
                elif i == len(seqs) - 1:
                    t.append(segs[i][1 + self.overlap:])
                else:
                    t.append(segs[i][1 + self.overlap: 1 + self.max_len + self.overlap])
            # Truncation drops the trailing EOS embedding from the last chunk.
            outputs = torch.cat(t, dim=0)[: len(_seqs[0])]
            assert outputs.shape[0] == len(_seqs[0])
            return outputs
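
# A quick sanity check (assumes the 8M checkpoint can be downloaded):
#   enc = EsmEncoder("facebook/esm2_t6_8M_UR50D", "cpu")
#   emb = enc(["MKTAYIAKQR"])  # -> torch.Size([10, 320, 2])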

class T5Encoder(nn.Module):
    """Wraps a Rostlab ProtT5 encoder checkpoint and returns per-residue embeddings."""

    def __init__(self, name: str, dev) -> None:
        super().__init__()
        self.dev = dev
        if name == "Rostlab/prot_t5_xl_uniref50":
            # Use the fp16, encoder-only export of the UniRef50 model.
            checkpoint = "Rostlab/prot_t5_xl_half_uniref50-enc"
        elif name == "Rostlab/prot_t5_xl_bfd":
            checkpoint = "Rostlab/prot_t5_xl_bfd"
        else:
            raise ValueError(f"Unsupported ProtT5 checkpoint: {name}")
        self.tokenizer = T5Tokenizer.from_pretrained(
            checkpoint,
            do_lower_case=False,
            legacy=False,
        )
        self.model = T5EncoderModel.from_pretrained(checkpoint).to(dev)
        # T5 appends only a trailing </s>; there is no BOS token.
        self.max_len = 960
        self.overlap = 31  # residues of shared context between adjacent chunks
        self.model.eval()
        self.model.half()
    def forward(self, _seqs):
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            seqs = _seqs[0]
            # Replace anything that is not an uppercase letter with X.
            seqs = re.sub(r"[^A-Z]", "X", seqs)
            # Split into max_len chunks with `overlap` residues of extra
            # context on the left and right where available.
            seqs = [
                seqs[max(0, i - self.overlap): (i + self.max_len + self.overlap)]
                for i in range(0, len(seqs), self.max_len)
            ]
            # ProtT5 expects residues separated by spaces.
            input_ids = self.tokenizer.batch_encode_plus(
                [" ".join(list(s)) for s in seqs],
                add_special_tokens=True,
                padding="longest",
            )["input_ids"]
            input_ids = torch.tensor(input_ids).to(self.dev)
            outputs = self.model(input_ids=input_ids)
            # Raw input embeddings alongside the final-layer hidden states.
            outputs0 = self.model.get_input_embeddings()(input_ids)
            outputs = outputs.last_hidden_state
            # (chunks, tokens, dim, 2): channel 0 = embeddings, 1 = last layer.
            outputs = torch.stack([outputs0, outputs], dim=-1)
            # Stitch the chunks back together, dropping each chunk's
            # left-overlap. Unlike ESM, T5 adds no BOS token, so token 0 is
            # already the first residue and must not be skipped.
            t = []
            for i in range(len(seqs)):
                if i == 0:
                    t.append(outputs[i, : self.max_len])
                elif i == len(seqs) - 1:
                    t.append(outputs[i, self.overlap:])
                else:
                    t.append(outputs[i, self.overlap: self.max_len + self.overlap])
            # Truncation drops the trailing </s> and any padding kept from
            # the last chunk.
            outputs = torch.cat(t, dim=0)[: len(_seqs[0])]
            assert outputs.shape[0] == len(_seqs[0]), \
                f"outputs shape {outputs.shape} does not match input length {len(_seqs[0])}"
            return outputs
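
# A quick sanity check (assumes the checkpoint can be downloaded and a GPU
# is available for the fp16 weights):
#   enc = T5Encoder("Rostlab/prot_t5_xl_uniref50", "cuda")
#   emb = enc(["MKTAYIAKQR"])  # -> torch.Size([10, 1024, 2])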

def get_model(name: str, dev):
    """Instantiate the encoder matching `name` (see EsmModelInfo for the registry)."""
if name in [
"esm2_t48_15B_UR50D",
"esm2_t36_3B_UR50D",
"esm2_t33_650M_UR50D",
"esm2_t30_150M_UR50D",
"esm2_t12_35M_UR50D",
"esm2_t6_8M_UR50D",
"esm1b_t33_650M_UR50S",
]:
d = EsmModelInfo(name)
return EsmEncoder(d["model"], dev)
elif name in ["prot_t5_xl_half_uniref50-enc", "prot_t5_xl_bfd"]:
d = EsmModelInfo(name)
return T5Encoder(d["model"], dev)
else:
raise ValueError(f"Unknown model name: {name}")
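
if __name__ == "__main__":
    # Minimal usage sketch (assumes the small esm2_t6_8M_UR50D checkpoint
    # can be fetched from the HuggingFace hub; the sequence is illustrative).
    encoder = get_model("esm2_t6_8M_UR50D", "cpu")
    embeddings = encoder(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
    # Expected shape: (sequence_length, 320, 2), where channel 0 holds the
    # raw input embeddings and channel 1 the final-layer hidden states.
    print(embeddings.shape)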