| import numpy as np |
| import json |
| import onnxruntime |
| from transformers import BertTokenizer, RobertaTokenizer |
| import torch |
|
|
def init():
    """Load the ONNX model and both tokenizers into module-level globals.

    After this call the globals ``session``, ``input_name``,
    ``prot_tokenizer`` and ``mol_tokenizer`` are set.
    """
    global session, prot_tokenizer, mol_tokenizer, input_name

    model_path = "models/affinity_predictor0734-seed2101.onnx"
    session = onnxruntime.InferenceSession(model_path)

    # Remember the name of the model's first declared input.
    first_input = session.get_inputs()[0]
    input_name = first_input.name

    # do_lower_case=False keeps the case of the protein tokens as given.
    prot_tokenizer = BertTokenizer.from_pretrained(
        "Rostlab/prot_bert",
        do_lower_case=False,
    )
    mol_tokenizer = RobertaTokenizer.from_pretrained(
        "seyonec/ChemBERTa-zinc-base-v1"
    )
|
|
def run(raw_data):
    """Predict the binding affinity for one protein / molecule pair.

    Args:
        raw_data: The request payload, either JSON text (``str``, ``bytes``
            or ``bytearray``) or an already-decoded mapping. It must provide
            the keys ``'protein'`` and ``'smiles'``.

    Returns:
        On success, the dict produced by ``convert_to_affinity``. On any
        failure, a JSON string of the form ``{"error": "<message>"}``;
        exceptions are never propagated to the caller.
    """
    try:
        # json.loads only accepts str/bytes/bytearray; a payload that has
        # already been decoded is used as-is instead of being rejected.
        if isinstance(raw_data, (str, bytes, bytearray)):
            data = json.loads(raw_data)
        else:
            data = raw_data
        prot_seq = data['protein']
        mol_smiles = data['smiles']

        prot_tokens = prot_tokenizer(preprocess_sequence(prot_seq),
                                     padding=True,
                                     max_length=3200,
                                     truncation=True,
                                     return_tensors='pt')
        # return_tensors='pt' already yields tensors, so the ids are used
        # directly; re-wrapping them with torch.tensor() and the
        # unsqueeze(0)/squeeze(0) round trip left the shape unchanged.
        prot_representations = prot_tokens['input_ids']

        mol_tokens = mol_tokenizer(mol_smiles,
                                   padding=True,
                                   max_length=278,
                                   truncation=True,
                                   return_tensors='pt')
        mol_representations = mol_tokens['input_ids']

        # NOTE(review): both id tensors appear to be (1, n_tokens); torch.cat
        # along dim 0 needs n_tokens to match between protein and molecule.
        # Confirm the layout the ONNX model expects before changing this.
        features = torch.cat((prot_representations, mol_representations), dim=0)

        outputs = session.run(
            None,
            {input_name: [features.numpy()], 'TrainingMode': np.array(False)},
        )
        # First output, first row, first column.
        affinity_normalized = outputs[0][0][0]

        return convert_to_affinity(affinity_normalized)
    except Exception as e:
        # Scoring boundary: report the failure as a payload, do not raise.
        return json.dumps({"error": str(e)})
|
|
def preprocess_sequence(seq):
    """Return *seq* with U, Z, O and B replaced by X and a space between characters.

    Args:
        seq: The sequence string to prepare.

    Returns:
        A string holding every character of the substituted sequence,
        separated by single spaces (empty input gives an empty string).
    """
    import re

    substituted = re.sub(r"[UZOB]", "X", seq)
    characters = list(substituted)
    return " ".join(characters)
|
|
def convert_to_affinity(normalized):
    """Rescale a normalized model output and express it in two forms.

    Args:
        normalized: The normalized value produced by the model.

    Returns:
        A dict with two entries:
        ``"neg_log10_affinity_M"`` is ``normalized * scale + mean`` using the
        fixed constants below; ``"affinity_uM"`` is ``10**6 * 10**(-v)`` where
        ``v`` is that first entry.
    """
    # NOTE(review): presumably the statistics used to normalize the training
    # targets -- confirm against the training pipeline.
    offset = 6.51286529169358
    factor = 1.5614094578916633

    rescaled = (normalized * factor) + offset
    result = {}
    result["neg_log10_affinity_M"] = rescaled
    result["affinity_uM"] = (10**6) * (10**(-rescaled))
    return result
|
|
if __name__ == "__main__":
    # Manual smoke test. It is guarded so that importing this module has no
    # side effects. init() has to run first because run() reads the globals
    # it creates, and the payload is serialized because run() parses its
    # argument with json.loads.
    init()
    print(run(json.dumps({"protein": "MILK", "smiles": "CCO"})))