Upload 46 files
Browse files- .gitattributes +3 -0
- tool/ImageAnalysis.py +68 -0
- tool/PCE.py +146 -0
- tool/__init__.py +13 -0
- tool/browsersearch.py +36 -0
- tool/chemspace.py +195 -0
- tool/coder.py +43 -0
- tool/comget/dataset.py +65 -0
- tool/comget/generator.py +206 -0
- tool/comget/model.py +301 -0
- tool/comget/ppcenos.json +1 -0
- tool/comget/ppcenos.pt +3 -0
- tool/comget/utils.py +275 -0
- tool/converters.py +154 -0
- tool/csv_search.py +34 -0
- tool/dap/.gitignore +160 -0
- tool/dap/OSC/test.ckpt +3 -0
- tool/dap/README.md +1 -0
- tool/dap/config/config_hparam.json +26 -0
- tool/dap/config/predict.json +26 -0
- tool/dap/requirements.txt +18 -0
- tool/dap/run.py +124 -0
- tool/dap/screen.py +118 -0
- tool/dap/train.py +454 -0
- tool/dap/util/attention_flow.py +195 -0
- tool/dap/util/attention_plot.py +93 -0
- tool/dap/util/boxplot.py +201 -0
- tool/dap/util/data/bindingdb_kd.tab +3 -0
- tool/dap/util/data/davis.tab +3 -0
- tool/dap/util/emetric.py +59 -0
- tool/dap/util/load_dataset.py +32 -0
- tool/dap/util/make_external_validation.py +28 -0
- tool/dap/util/utils.py +45 -0
- tool/dataset.csv +0 -0
- tool/deepacceptor/RF.py +70 -0
- tool/deepacceptor/deepacceptor.pkl +3 -0
- tool/deepacceptor/dict.json +1 -0
- tool/deepdonor/pm.pkl +3 -0
- tool/deepdonor/sm.pkl +3 -0
- tool/graphconverter.py +33 -0
- tool/orbital.py +94 -0
- tool/pdfreader.py +86 -0
- tool/property.py +220 -0
- tool/rag.py +101 -0
- tool/rag/index.faiss +3 -0
- tool/rag/index.pkl +3 -0
- tool/search.py +156 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tool/dap/util/data/bindingdb_kd.tab filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
tool/dap/util/data/davis.tab filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
tool/rag/index.faiss filter=lfs diff=lfs merge=lfs -text
|
tool/ImageAnalysis.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Sat Oct 26 15:35:19 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 9 |
+
from langchain.tools import BaseTool
|
| 10 |
+
from langchain_openai import ChatOpenAI
|
| 11 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 12 |
+
from langchain.base_language import BaseLanguageModel
|
| 13 |
+
import base64
|
| 14 |
+
from io import BytesIO
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def convert_to_base64(pil_image):
    """Serialize a PIL image to a base64-encoded PNG string."""
    png_buffer = BytesIO()
    pil_image.save(png_buffer, format="PNG")
    return base64.b64encode(png_buffer.getvalue()).decode("utf-8")
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Imageanalysis(BaseTool):
    """Tool that sends a local image plus a text query to a vision-capable LLM."""

    name: str = "Imageanalysis"
    description: str = (
        "Useful to answer questions according to the image, figure, diagram or graph. "
        "Useful to analysis the information in the image, figure, diagram or graph. "
        "Input query about image/figure/graph/diagram, return the response"
    )
    return_direct: bool = True
    llm: BaseLanguageModel = None  # vision-capable chat model, set in __init__
    path: str = None  # filesystem path of the image to analyse

    def __init__(self, path):
        super().__init__()
        import os  # local import: the key must come from the environment
        # SECURITY FIX: an API key was hard-coded here. That key is leaked and
        # must be revoked; supply the key via the OPENAI_API_KEY env var.
        self.llm = ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url="https://www.dmxapi.com/v1",
        )
        self.path = path

    def _run(self, query) -> str:
        """Send *query* and the image at ``self.path`` to the model.

        Returns the model's reply text, or the exception message on failure.
        """
        try:
            pil_image = Image.open(self.path)
            # BUG FIX: the RGB conversion was computed but never used; encode
            # the converted image so palette/alpha images are handled.
            rgb_im = pil_image.convert('RGB')
            image_b64 = convert_to_base64(rgb_im)
            message = HumanMessage(
                content=[
                    {"type": "text", "text": query},
                    {
                        "type": "image_url",
                        # BUG FIX: the payload is PNG (see convert_to_base64),
                        # so declare image/png rather than image/jpeg.
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            )
            response = self.llm.invoke([message])
            return response.content
        except Exception as e:
            return str(e)

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
| 67 |
+
|
| 68 |
+
|
tool/PCE.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Created on Wed Sep 11 10:27:20 2024
|
| 5 |
+
|
| 6 |
+
@author: BM109X32G-10GPU-02
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from langchain.tools import BaseTool
|
| 10 |
+
from rdkit import Chem
|
| 11 |
+
from rdkit.Chem import rdMolDescriptors
|
| 12 |
+
from rdkit.Chem import Descriptors
|
| 13 |
+
from .deepacceptor import RF
|
| 14 |
+
from .deepdonor import sm, pm
|
| 15 |
+
from .dap import run, screen
|
| 16 |
+
import pandas as pd
|
| 17 |
+
|
| 18 |
+
class acceptor_predictor(BaseTool):
    """Score an acceptor molecule via the DeepAcceptor RF model."""

    name: str = "acceptor_predictor"
    description: str = (
        "Input acceptor SMILES , returns the score of the acceptor."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        # Validate and canonicalize the SMILES before scoring.
        parsed = Chem.MolFromSmiles(smiles)
        if parsed is None:
            return "Invalid SMILES string"
        canonical = Chem.MolToSmiles(parsed)
        pce = RF.main(str(canonical))
        return f'The power conversion efficiency (PCE) is predicted to be {pce} (predicted by DeepAcceptor) '

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 37 |
+
|
| 38 |
+
class donor_predictor(BaseTool):
    """Predict the PCE of a donor, both as a small molecule and as a polymer."""

    name: str = "donor_predictor"
    description: str = (
        "Input donor SMILES , returns the score of the donor."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        """Validate the SMILES, then score it with both DeepDonor models.

        Returns a sentence with the small-molecule and polymer PCE predictions,
        or an error message for unparsable SMILES.
        """
        # BUG FIX: the original parsed and validated the SMILES twice in a
        # row; one check is sufficient.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        sdpce = sm.main(str(smiles))  # small-molecule donor model
        pdpce = pm.main(str(smiles))  # polymer donor model
        return f'The power conversion efficiency (PCE) of the given molecule is predicted to be {sdpce} as a small molecule donor , and {pdpce} as a polymer donor(predicted by DeepDonor) '

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class dap_predictor(BaseTool):
    """Predict performance for a single donor/acceptor pair."""

    name: str = "dap_predictor"
    description: str = (
        "Input SMILES of D/A pairs(separated by '.') , returns the performance of the D/A pairs ."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles_pair: str) -> str:
        parts = smiles_pair.split(".")
        # Exactly one donor and one acceptor are expected.
        if len(parts) != 2:
            return "Input error, please input two smiles strings separated by '.'"
        donor, acceptor = parts
        return run.smiles_aas_test(str(donor), str(acceptor))

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class dap_screen(BaseTool):
    """Batch-screen a dataset file of donor/acceptor pairs."""

    name: str = "dap_screen"
    description: str = (
        "Input dataset path containing D/A pairs, returns the files of prediction results."
    )
    return_direct: bool = True

    def __init__(self):
        super().__init__()

    def _run(self, file_path: str) -> str:
        # Delegate whole-file screening to the dap screening module.
        predictions = screen.smiles_aas_test(file_path)
        return predictions

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
from .comget import generator
|
| 114 |
+
|
| 115 |
+
class molgen(BaseTool):
    """Generate polymer donor candidates and return one matching a target PCE."""

    name: str = "donorgen"
    description: str = (
        "Useful to generate polymer donor molecules with required PCE. "
        "Input the values of PCE , return the SMILES"
    )

    def __init__(self):
        super().__init__()

    def _run(self, value) -> str:
        """Generate candidates and return the first whose predicted PCE is
        within 1.0 of the requested *value*.

        Returns an explanatory message on failure instead of None.
        """
        try:
            target = float(value)  # hoisted: was re-converted every iteration
            results = generator.generation(value)
            for candidate in results['smiles']:
                predicted = pm.main(str(candidate))
                # BUG FIX: the original had an unreachable `break` after this
                # `return` statement.
                if abs(predicted - target) < 1.0:
                    return f"The SMILES of generated donor is {candidate}, its predicted PCE is {predicted}."
            # BUG FIX: the original fell off the loop and returned None when
            # no candidate was close enough; report that explicitly.
            return f"No generated donor had a predicted PCE within 1.0 of {value}."
        except Exception as e:
            return str(e)

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
tool/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""load all tools."""
|
| 2 |
+
|
| 3 |
+
from .coder import *
|
| 4 |
+
from .property import *
|
| 5 |
+
from .search import *
|
| 6 |
+
from .PCE import *
|
| 7 |
+
from .converters import *
|
| 8 |
+
from .orbital import *
|
| 9 |
+
from .graphconverter import *
|
| 10 |
+
from .ImageAnalysis import *
|
| 11 |
+
from .pdfreader import *
|
| 12 |
+
from .rag import *
|
| 13 |
+
from .browsersearch import *
|
tool/browsersearch.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_openai import ChatOpenAI
|
| 2 |
+
from browser_use import Agent
|
| 3 |
+
import asyncio
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
load_dotenv()
|
| 6 |
+
from langchain.tools import BaseTool
|
| 7 |
+
|
| 8 |
+
async def main(task):
    """Run a browser_use Agent on *task* and return its result.

    The agent drives a real browser session; this coroutine awaits completion.
    """
    import os  # local import: the key must come from the environment
    agent = Agent(
        task=task,
        # SECURITY FIX: an API key was hard-coded here. That key is leaked and
        # must be revoked; supply the key via the OPENAI_API_KEY env var.
        llm=ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url="https://www.dmxapi.com/v1",
        ),
    )
    result = await agent.run()
    return result
|
| 16 |
+
|
| 17 |
+
class browseruse(BaseTool):
    """Tool wrapper that drives the async browser agent from sync code."""

    name: str = "browseruse"
    description: str = ("Calling the browser to search for information in specific website"
                        "input query, return the searching results")

    def __init__(self):
        super().__init__()

    def _run(self, task: str) -> str:
        # Bridge the async agent coroutine into this synchronous interface.
        return asyncio.run(main(task))

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
tool/chemspace.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import molbloom
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import requests
|
| 6 |
+
from langchain.tools import BaseTool
|
| 7 |
+
|
| 8 |
+
from utils import is_smiles
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ChemSpace:
    """Thin client for the chem-space.com REST API (search / purchasing info)."""

    def __init__(self, chemspace_api_key=None):
        self.chemspace_api_key = chemspace_api_key
        self._renew_token()  # Create token

    def _renew_token(self):
        # Exchange the long-lived API key for a short-lived access token.
        self.chemspace_token = requests.get(
            url="https://api.chem-space.com/auth/token",
            headers={
                "Accept": "application/json",
                "Authorization": f"Bearer {self.chemspace_api_key}",
            },
        ).json()["access_token"]

    def _make_api_request(self, query, request_type, count, categories):
        """
        Make a generic request to chem-space API.

        Categories request.
        CSCS: Custom Request: Could be useful for requesting whole synthesis
        CSMB: Make-On-Demand Building Blocks
        CSSB: In-Stock Building Blocks
        CSSS: In-stock Screening Compounds
        CSMS: Make-On-Demand Screening Compounds
        """

        def _do_request():
            data = requests.request(
                "POST",
                url=f"https://api.chem-space.com/v3/search/{request_type}?count={count}&page=1&categories={categories}",
                headers={
                    "Accept": "application/json; version=3.1",
                    "Authorization": f"Bearer {self.chemspace_token}",
                },
                data={"SMILES": f"{query}"},
            ).json()
            return data

        data = _do_request()

        # renew token if token is invalid, then retry once
        if "message" in data.keys():
            if data["message"] == "Your request was made with invalid credentials.":
                self._renew_token()
                data = _do_request()
        return data

    def _convert_single(self, query, search_type: str):
        """Do query for a single molecule"""
        data = self._make_api_request(query, "exact", 1, "CSCS,CSMB,CSSB")
        if data["count"] > 0:
            return data["items"][0][search_type]
        else:
            return "No data was found for this compound."

    def convert_mol_rep(self, query, search_type: str = "smiles"):
        """Convert one molecule, or a ', '-separated list, to *search_type*."""
        if ", " in query:
            query_list = query.split(", ")
        else:
            query_list = [query]
        smi = ""
        try:
            for q in query_list:
                # BUG FIX: originally interpolated `query` (the whole input)
                # instead of `q`, so multi-molecule results were mislabeled.
                smi += f"{q}'s {search_type} is: {str(self._convert_single(q, search_type))}"
            return smi
        except Exception:
            return "The input provided is wrong. Input either a single molecule, or multiple molecules separated by a ', '"

    def buy_mol(self, smiles, request_type="exact", count=1):
        """
        Get data about purchasing compounds.

        smiles: smiles string of the molecule you want to buy
        request_type: one of "exact", "sim" (search by similarity), "sub" (search by substructure).
        count: retrieve data for this many substances max.

        Raises ValueError for an unknown request_type.
        """

        def purchasable_check(s):
            """Checks if molecule is available for purchase (ZINC20)"""
            if not is_smiles(s):
                try:
                    s = self.convert_mol_rep(s, "smiles")
                except Exception:  # was a bare except
                    return "Invalid SMILES string."
            try:
                r = molbloom.buy(s, canonicalize=True)
            except Exception:  # was a bare except
                print("invalid smiles")
                return False
            return bool(r)

        purchasable = purchasable_check(smiles)

        if request_type == "exact":
            categories = "CSMB,CSSB"
        elif request_type in ["sim", "sub"]:
            categories = "CSSS,CSMS"
        else:
            # BUG FIX: `categories` was left unbound for any other
            # request_type, raising UnboundLocalError below.
            raise ValueError('request_type must be one of "exact", "sim", "sub"')

        data = self._make_api_request(smiles, request_type, count, categories)

        try:
            if data["count"] == 0:
                if purchasable:
                    return "Compound is purchasable, but price is unknown."
                else:
                    return "Compound is not purchasable."
        except KeyError:
            return "Invalid query, try something else. "

        print(f"Obtaining data for {data['count']} substances.")

        dfs = []
        # Convert this data into df
        for item in data["items"]:
            dfs_tmp = []
            smiles = item["smiles"]
            offers = item["offers"]

            for off in offers:
                df_tmp = pd.DataFrame(off["prices"])
                df_tmp["vendorName"] = off["vendorName"]
                df_tmp["time"] = off["shipsWithin"]
                df_tmp["purity"] = off["purity"]

                dfs_tmp.append(df_tmp)

            df_this = pd.concat(dfs_tmp)
            df_this["smiles"] = smiles
            dfs.append(df_this)

        df = pd.concat(dfs).reset_index(drop=True)

        df["quantity"] = df["pack"].astype(str) + df["uom"]
        df["time"] = df["time"].astype(str) + " days"

        df = df.drop(columns=["pack", "uom"])
        # Remove all entries that are not numbers
        df = df[df["priceUsd"].astype(str).str.isnumeric()]

        # BUG FIX: idxmin() returns an index *label*, but the original used
        # positional .iloc; after the filtering above labels have gaps, so
        # .iloc could select the wrong row or raise. Use label-based .loc.
        cheapest = df.loc[df["priceUsd"].astype(float).idxmin()]
        return f"{cheapest['quantity']} of this molecule cost {cheapest['priceUsd']} USD and can be purchased at {cheapest['vendorName']}."
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class GetMoleculePrice(BaseTool):
    """Tool returning the cheapest purchase option for a molecule via Chemspace."""

    name: str = "GetMoleculePrice"
    description: str = "Get the cheapest available price of a molecule."
    chemspace_api_key: str = None
    url: str = None

    def __init__(self, chemspace_api_key: str = None):
        super().__init__()
        self.chemspace_api_key = chemspace_api_key
        self.url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}"

    def _run(self, query: str) -> str:
        # Without an API key the Chemspace client cannot authenticate.
        if not self.chemspace_api_key:
            return "No Chemspace API key found. This tool may not be used without a Chemspace API key."
        try:
            client = ChemSpace(self.chemspace_api_key)
            return client.buy_mol(query)
        except Exception as e:
            return str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
tool/coder.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Sat Oct 26 15:35:19 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from langchain_community.embeddings import OllamaEmbeddings
|
| 9 |
+
from langchain.tools import BaseTool
|
| 10 |
+
from langchain_openai import ChatOpenAI
|
| 11 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 12 |
+
from langchain.base_language import BaseLanguageModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class codewriter(BaseTool):
    """Tool that delegates code-writing requests to a chat LLM."""

    name: str = "codewriter"
    description: str = (
        "Useful to answer questions that require writing codes "
        "return the usage and instruction of codes"
    )
    llm: BaseLanguageModel = None  # chat model, set in __init__

    def __init__(self):
        super().__init__()
        import os  # local import: the key must come from the environment
        # SECURITY FIX: an API key was hard-coded here. That key is leaked and
        # must be revoked; supply the key via the OPENAI_API_KEY env var.
        self.llm = ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url="https://www.dmxapi.com/v1",
        )

    def _run(self, query) -> str:
        """Ask the LLM to write code for *query*; returns the reply text."""
        messages = [
            SystemMessage(content="You are an expert at writing code, write the corresponding code based on the inputs"),
            HumanMessage(content=query),
        ]
        response = self.llm.invoke(messages)
        # BUG FIX: the original returned the whole message object even though
        # the signature (and the BaseTool contract) promise a string.
        return response.content

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
tool/comget/dataset.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
from utils import SmilesEnumerator
|
| 4 |
+
import numpy as np
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
class SmileDataset(Dataset):
    """Torch dataset of SMILES strings with a property value and a scaffold.

    Each item is tokenized with a SMILES regex, padded/truncated to fixed
    lengths, and returned as (input ids, shifted target ids, property tensor,
    scaffold ids) for autoregressive training.
    """

    def __init__(self, args, data, content, block_size, aug_prob=0.5, prop=None, scaffold=None, scaffold_maxlen=None):
        # Vocabulary is built character-wise from `content`, not from `data`.
        chars = sorted(list(set(content)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d smiles, %d unique characters.' % (data_size, vocab_size))

        self.stoi = {ch: i for i, ch in enumerate(chars)}  # char -> id
        self.itos = {i: ch for i, ch in enumerate(chars)}  # id -> char
        self.max_len = block_size        # max tokens per SMILES
        self.vocab_size = vocab_size
        self.data = data                 # SMILES strings
        self.prop = prop                 # per-item property values
        self.sca = scaffold              # per-item scaffold SMILES
        self.scaf_max_len = scaffold_maxlen
        self.debug = args.debug
        self.tfm = SmilesEnumerator()    # SMILES randomization for augmentation
        self.aug_prob = aug_prob

    def __len__(self):
        if self.debug:
            # BUG FIX: `math` was referenced without being imported anywhere
            # in this module, so the debug path raised NameError.
            import math
            return math.ceil(len(self.data) / (self.max_len + 1))
        else:
            return len(self.data)

    def __getitem__(self, idx):
        smiles, prop, scaffold = self.data[idx], self.prop[idx], self.sca[idx]  # self.prop.iloc[idx, :].values --> if multiple properties
        smiles = smiles.strip()
        scaffold = scaffold.strip()

        # Randomized-SMILES augmentation with probability aug_prob.
        p = np.random.uniform()
        if p < self.aug_prob:
            smiles = self.tfm.randomize_smiles(smiles)

        # Tokenizer: bracket atoms, two-letter halogens, bond/ring symbols,
        # and '<' which doubles as the padding token.
        pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
        regex = re.compile(pattern)
        # Pad with '<' up to max_len tokens.
        smiles += str('<')*(self.max_len - len(regex.findall(smiles)))

        # NOTE(review): truncation slices by *character*, not by token, so a
        # multi-character token at the boundary can be cut — confirm intended.
        if len(regex.findall(smiles)) > self.max_len:
            smiles = smiles[:self.max_len]

        smiles = regex.findall(smiles)

        # Same pad-then-truncate treatment for the scaffold sequence.
        scaffold += str('<')*(self.scaf_max_len - len(regex.findall(scaffold)))

        if len(regex.findall(scaffold)) > self.scaf_max_len:
            scaffold = scaffold[:self.scaf_max_len]

        scaffold = regex.findall(scaffold)

        dix = [self.stoi[s] for s in smiles]
        sca_dix = [self.stoi[s] for s in scaffold]

        sca_tensor = torch.tensor(sca_dix, dtype=torch.long)
        x = torch.tensor(dix[:-1], dtype=torch.long)  # model input
        y = torch.tensor(dix[1:], dtype=torch.long)   # next-token targets
        # prop = torch.tensor([prop], dtype=torch.long)
        prop = torch.tensor([prop], dtype=torch.float)
        return x, y, prop, sca_tensor
|
tool/comget/generator.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import argparse
|
| 7 |
+
from .model import GPT, GPTConfig
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
#import seaborn as sns
|
| 13 |
+
from .moses.utils import get_mol
|
| 14 |
+
import re
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
from rdkit.Chem import RDConfig
|
| 18 |
+
|
| 19 |
+
import selfies as sf
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
| 23 |
+
from .utils import sample, canonic_smiles
|
| 24 |
+
import sascorer
|
| 25 |
+
from rdkit import Chem
|
| 26 |
+
from rdkit.Chem.rdMolDescriptors import CalcTPSA
|
| 27 |
+
import os
|
| 28 |
+
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 29 |
+
|
| 30 |
+
def get_selfie_and_smiles_encodings_for_dataset(smiles):
    """
    Returns encoding, alphabet and length of largest molecule in SMILES and
    SELFIES, given a file containing SMILES molecules.

    input:
        csv file with molecules. Column's name must be 'smiles'.
    output:
        - selfies encoding
        - selfies alphabet
        - longest selfies string
        - smiles encoding (equivalent to file content)
        - smiles alphabet (character based)
        - longest smiles string
    """
    smiles_arr = np.asanyarray(smiles)

    # Character-level SMILES alphabet plus a space used for padding.
    char_alphabet = list(set("".join(smiles_arr)))
    char_alphabet.append(" ")  # for padding
    max_smiles_len = len(max(smiles_arr, key=len))

    print("--> Translating SMILES to SELFIES...")
    selfies_strings = [sf.encoder(s) for s in smiles_arr]

    symbol_set = sf.get_alphabet_from_selfies(selfies_strings)
    symbol_set.add("[nop]")  # padding symbol for SELFIES
    selfies_symbols = list(symbol_set)

    max_selfies_len = max(sf.len_selfies(s) for s in selfies_strings)

    print("Finished translating SMILES to SELFIES.")

    return (selfies_strings, selfies_symbols, max_selfies_len,
            smiles_arr, char_alphabet, max_smiles_len)
|
| 66 |
+
|
| 67 |
+
def generation(value):
    """Generate candidate molecules conditioned on a target PCE value.

    Builds the GPT model configured for the 'ppcenos' dataset, loads the
    pretrained weights from tool/comget/ppcenos.pt, and autoregressively
    samples SELFIES strings conditioned on the requested
    power-conversion-efficiency value.

    Args:
        value: target PCE; anything accepted by float().

    Returns:
        pandas.DataFrame with columns 'molecule' (RDKit Mol) and 'smiles'
        for every valid generated molecule.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--scaffold', action='store_true', default=False, help='condition on scaffold')
    parser.add_argument('--lstm', action='store_true', default=False, help='use lstm for transforming scaffold')
    parser.add_argument('--data_name', type=str, default='moses2', help="name of the dataset to train on", required=False)
    parser.add_argument('--batch_size', type=int, default=512, help="batch size", required=False)
    parser.add_argument('--gen_size', type=int, default=10000, help="number of molecules to generate", required=False)
    parser.add_argument('--vocab_size', type=int, default=26, help="vocabulary size", required=False)
    parser.add_argument('--block_size', type=int, default=54, help="context length", required=False)
    parser.add_argument('--props', nargs="+", default=[], help="properties to be used for condition", required=False)
    parser.add_argument('--n_layer', type=int, default=8, help="number of layers", required=False)
    parser.add_argument('--n_head', type=int, default=8, help="number of heads", required=False)
    parser.add_argument('--n_embd', type=int, default=256, help="embedding dimension", required=False)
    parser.add_argument('--lstm_layers', type=int, default=2, help="number of layers in lstm", required=False)

    # Parse an EMPTY argv: this function is called from library code, so it
    # must not consume (or crash on) the host process's command-line flags.
    # Every relevant value is hard-set below anyway.
    args = parser.parse_args([])

    # Fixed settings for the PCE-conditioned 'ppcenos' checkpoint.
    args.data_name = 'ppcenos'
    args.vocab_size = 29
    args.block_size = 196       # max SELFIES length the model was trained with
    args.gen_size = 20
    args.batch_size = 5
    args.csv_name = 'ppcenos'
    args.props = ['pce']
    context = "[C]"             # seed token for generation
    args.scaffold = False

    # Tokenizer: one token per bracketed SELFIES/SMILES symbol, bond, ring digit, ...
    pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)

    if ('moses' in args.data_name) and args.scaffold:
        scaffold_max_len = 48
    elif ('guacamol' in args.data_name):
        scaffold_max_len = 107
    else:
        scaffold_max_len = 181

    # Token -> index vocabulary saved next to the checkpoint.
    with open('tool/comget/' + f'{args.data_name}.json', 'r') as fh:
        stoi = json.load(fh)
    itos = {i: ch for ch, i in stoi.items()}
    print(len(itos))

    num_props = len(args.props)
    mconf = GPTConfig(args.vocab_size, args.block_size, num_props=num_props,
                      n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
                      scaffold=args.scaffold, scaffold_maxlen=scaffold_max_len,
                      lstm=args.lstm, lstm_layers=args.lstm_layers)
    model = GPT(mconf)

    args.model_weight = f'{args.csv_name}.pt'
    # Fall back to CPU when no GPU is present instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.load_state_dict(torch.load('tool/comget/' + args.model_weight, map_location=device))
    model.to(device)
    print('Model loaded')

    gen_iter = math.ceil(args.gen_size / args.batch_size)

    # Single-property conditioning on the requested PCE value.
    prop2value = {'pce': [float(value)]}

    prop_condition = None
    if len(args.props) > 0:
        prop_condition = prop2value['_'.join(args.props)]

    all_dfs = []

    if prop_condition is not None:
        for c in prop_condition:
            molecules = []
            selfies = []
            for _ in tqdm(range(gen_iter)):
                # Repeat the seed context for the whole batch.
                x = torch.tensor([stoi[s] for s in regex.findall(context)],
                                 dtype=torch.long)[None, ...].repeat(args.batch_size, 1).to(device)
                if len(args.props) == 1:
                    p = torch.tensor([c]).repeat(args.batch_size, 1).to(device)          # single condition
                else:
                    p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to(device)  # multiple conditions
                y = sample(model, x, 300, temperature=1.0, sample=True, top_k=10,
                           prop=p, scaffold=None)
                for gen_mol in y:
                    completion = ''.join([itos[int(i)] for i in gen_mol])
                    # '<' is the padding/terminator token -- strip it.
                    selfies.append(completion.replace('<', ''))

            # Decode each SELFIES back to SMILES; keep only RDKit-valid molecules.
            # (Previous code passed the string through eval(repr(...)), a no-op.)
            for ind, s in enumerate(selfies):
                smi = sf.decoder(s)
                mol = get_mol(smi)
                if mol:
                    molecules.append(mol)
                else:
                    print(ind)
                    print(s)

            # Report how many generated strings produced valid molecules
            # (was a dead expression-statement before).
            print("Valid molecules = {}".format(len(molecules)))

            mol_dict = [{'molecule': m, 'smiles': Chem.MolToSmiles(m)} for m in molecules]
            all_dfs.append(pd.DataFrame(mol_dict))

    results = pd.concat(all_dfs)
    return results
|
tool/comget/model.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPT model:
|
| 3 |
+
- the initial stem consists of a combination of token encoding and a positional encoding
|
| 4 |
+
- the meat of it is a uniform sequence of Transformer blocks
|
| 5 |
+
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
|
| 6 |
+
- all blocks feed into a central residual pathway similar to resnets
|
| 7 |
+
- the final decoder is a linear projection into a vanilla Softmax classifier
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
from torch.nn import functional as F
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
class GPTConfig:
    """Base GPT configuration shared by all model variants.

    The class-level dropout rates act as defaults.  Any keyword argument
    passed to the constructor becomes an instance attribute, so arbitrary
    extra hyper-parameters (n_layer, n_head, num_props, ...) can be attached.
    """
    # default dropout probabilities
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        # Attach every extra hyper-parameter as an attribute.
        for name, value in kwargs.items():
            setattr(self, name, value)
|
| 30 |
+
|
| 31 |
+
class GPT1Config(GPTConfig):
    """GPT-1 sized preset (roughly 125M parameters): 12 layers, 12 heads, 768-dim embeddings."""
    n_layer = 12
    n_head = 12
    n_embd = 768
|
| 36 |
+
|
| 37 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization with a learnable per-feature gain.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # learnable per-feature scale, initialized to 1 (identity)
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the mean square in float32 for numerical stability, then
        # rescale the input by the reciprocal RMS and restore x's dtype.
        mean_square = x.to(torch.float32).pow(2).mean(dim=self.dim, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.eps)
        return (self.scale * normalized).type_as(x)
|
| 59 |
+
|
| 60 |
+
class CausalSelfAttention(nn.Module):
    """A vanilla multi-head masked (causal) self-attention layer with an
    output projection at the end.

    The forward pass uses only ``q_proj``/``kv_proj``/``c_proj``.  The
    ``key``/``query``/``value``/``proj`` layers below are constructed but
    never called; NOTE(review): they still register parameters in the
    state_dict, presumably for checkpoint compatibility -- confirm before
    removing them.
    """

    def __init__(self, config):
        super().__init__()
        # head dimension must divide evenly
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        # (unused in forward -- see class docstring)
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # query projection actually used by forward()
        self.q_proj = nn.Linear(
            config.n_embd ,
            config.n_embd ,
            bias=False,
        )
        # fused key+value projection (split into k and v in forward)
        self.kv_proj = nn.Linear(
            config.n_embd ,
            2 * config.n_embd ,
            bias=False,
        )
        # output projection used by forward()
        self.c_proj = nn.Linear(
            config.n_embd ,
            config.n_embd ,
            bias=False,
        )
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # second output projection (unused in forward -- see class docstring)
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in
        # the input sequence; sized for block_size plus the extra conditioning
        # positions (1 property token and/or scaffold_maxlen scaffold tokens)
        num = int(bool(config.num_props)) + int(config.scaffold_maxlen) #int(config.lstm_layers) # int(config.scaffold)
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                             .view(1, 1, config.block_size + num, config.block_size + num))

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, layer_past=None):
        # x: (batch B, sequence T, embedding C).  layer_past is accepted but unused.
        B, T, C = x.size()

        q = self.q_proj(x)
        # one matmul produces both k and v; split along the feature dim
        k, v = self.kv_proj(x).split(self.n_embd, dim=2)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        ) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        ) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        )
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # mask out future positions before the softmax
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # keep the (pre-dropout) attention weights so callers can inspect them
        attn_save = att
        att = self.attn_drop(att)
        y = att @ v
        # re-assemble all head outputs side by side: (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.c_proj(y)

        return y, attn_save
|
| 137 |
+
|
| 138 |
+
def find_multiple(n: int, k: int) -> int:
    """Round *n* up to the nearest multiple of *k* (n itself when it already divides)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class MLP(nn.Module):
    """SwiGLU-style feed-forward block: silu(W1 x) * (W2 x) projected back to n_embd."""

    def __init__(self, config):
        super().__init__()
        # Hidden width: two-thirds of (4 * n_embd * n_head), rounded up to a
        # multiple of 256 (same value find_multiple(..., 256) yields).
        raw_hidden = int(2 * (4 * config.n_embd * config.n_head) / 3)
        n_hidden = ((raw_hidden + 255) // 256) * 256

        # gate branch, value branch, and down-projection -- all bias-free
        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x):
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
|
| 165 |
+
|
| 166 |
+
class Block(nn.Module):
    """A pre-norm Transformer block: RMSNorm -> attention -> residual,
    then RMSNorm -> MLP -> residual.  Also returns the layer's attention map.
    """

    def __init__(self, config):
        super().__init__()
        # pre-norm layers for the attention and MLP sub-blocks
        self.rms_1 = RMSNorm(config.n_embd)
        self.rms_2 = RMSNorm(config.n_embd)
        # LayerNorms are constructed but the forward pass only uses the
        # RMSNorms above; kept so the parameter set stays unchanged.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        attn_out, attn_weights = self.attn(self.rms_1(x))
        x = x + attn_out                     # residual around attention
        x = x + self.mlp(self.rms_2(x))      # residual around MLP
        return x, attn_weights
|
| 182 |
+
|
| 183 |
+
class GPT(nn.Module):
    """The full GPT language model, with a context size of block_size.

    Optionally prepends conditioning positions to the token sequence:
    - one property token when ``config.num_props`` is set (the property
      vector is projected through ``prop_nn``), and/or
    - scaffold token embeddings when ``config.scaffold`` is set (optionally
      summarized by an LSTM when ``config.lstm`` is set).
    Logits at those conditioning positions are stripped before returning.
    """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # type 0 marks conditioning positions, type 1 the real sequence
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            # projects the property vector into embedding space
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)

        # learned absolute positional embeddings
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = RMSNorm(config.n_embd )
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size

        if config.lstm:
            # summarizes scaffold token embeddings (see forward)
            self.lstm = nn.LSTM(input_size = config.n_embd, hidden_size = config.n_embd, num_layers = config.lstm_layers, dropout = 0.3, bidirectional = False)
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """Return the maximum sequence length the model accepts."""
        return self.block_size

    def _init_weights(self, module):
        # Normal init for Linear and Embedding weights; std shrinks with
        # depth (0.02 / sqrt(2 * n_layer)).
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )

    def configure_optimizers(self, parameters, train_config):
        """Build an AdamW optimizer using train_config.learning_rate / .betas."""
        optimizer = torch.optim.AdamW(parameters, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None, prop = None, scaffold = None):
        """Run the model.

        Args:
            idx: (batch, seq) token indices.
            targets: optional (batch, seq) indices for the cross-entropy loss.
            prop: property tensor; required when config.num_props is set.
            scaffold: scaffold token indices; required when config.scaffold is set.

        Returns:
            (logits, loss, attn_maps): logits cover only the real sequence
            positions (conditioning positions stripped); loss is None when
            no targets are given; attn_maps is a list of one attention
            tensor per layer.
        """
        b, t = idx.size()

        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        # type 1 embedding for every real sequence position
        type_embeddings = self.type_emb(torch.ones((
            b,t), dtype = torch.long, device = idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)

        if self.config.num_props:
            # prepend one property token (type embedding 0)
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1)) # for single property
            else:
                p = self.prop_nn(prop) # for multiproperty
            p += type_embd
            x = torch.cat([p, x], 1)

        if self.config.scaffold:
            # prepend scaffold embeddings (type embedding 0)
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))

            scaffold_embeds = self.tok_emb(scaffold)
            if self.config.lstm:
                # use the LSTM's final hidden states as the scaffold summary
                scaffold_embeds = self.lstm(scaffold_embeds.permute(1,0,2))[1][0]
                scaffold_embeds = scaffold_embeds.permute(1,0,2)
            scaffold_embeds += type_embd
            x = torch.cat([scaffold_embeds, x], 1)

        # run the transformer blocks, collecting each layer's attention map
        attn_maps = []

        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)

        x = self.ln_f(x)
        logits = self.head(x)

        # number of leading conditioning positions to strip from the logits
        if self.config.num_props and self.config.scaffold:
            num = int(bool(self.config.num_props)) + int(self.config.scaffold_maxlen)
        elif self.config.num_props:
            num = int(bool(self.config.num_props))
        elif self.config.scaffold:
            num = int(self.config.scaffold_maxlen)
        else:
            num = 0

        logits = logits[:, num:, :]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, attn_maps # (num_layers, batch_size, num_heads, max_seq_len, max_seq_len)
|
tool/comget/ppcenos.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"<": 0, "[#Branch1]": 1, "[#Branch2]": 2, "[#C]": 3, "[#N]": 4, "[=Branch1]": 5, "[=Branch2]": 6, "[=C]": 7, "[=N]": 8, "[=O]": 9, "[=Ring1]": 10, "[=Ring2]": 11, "[=S]": 12, "[Branch1]": 13, "[Branch2]": 14, "[C]": 15, "[Cl]": 16, "[F]": 17, "[GeH2]": 18, "[Ge]": 19, "[NH1]": 20, "[N]": 21, "[O]": 22, "[P]": 23, "[Ring1]": 24, "[Ring2]": 25, "[S]": 26, "[Se]": 27, "[nop]": 28}
|
tool/comget/ppcenos.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ddee4e16df14ee00e9736c66755977774edf1259c46afb03469c99ca7659fbf5
|
| 3 |
+
size 160173846
|
tool/comget/utils.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from .moses.utils import get_mol
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import threading
|
| 11 |
+
|
| 12 |
+
def set_seed(seed):
    """Seed every RNG used here (python `random`, numpy, torch CPU and CUDA)."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
| 17 |
+
|
| 18 |
+
def top_k_logits(logits, k):
    """Mask all but the k highest logits per row with -inf (returns a copy)."""
    topk_vals, _ = torch.topk(logits, k)
    threshold = topk_vals[:, [-1]]  # k-th largest value per row, kept as a column
    masked = logits.clone()
    masked[masked < threshold] = -float('Inf')
    return masked
|
| 23 |
+
|
| 24 |
+
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, prop = None, scaffold = None):
    """
    Take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.

    Args:
        model: autoregressive model returning (logits, loss, attn_maps)
            and exposing get_block_size().
        x: (batch, seq) tensor of starting token indices.
        steps: number of tokens to generate.
        temperature: divisor applied to the final-step logits before softmax.
        sample: draw from the distribution when True, else take the argmax.
        top_k: if set, restrict candidates to the k most likely tokens.
        prop, scaffold: conditioning inputs passed through to the model.

    Returns:
        x extended with the generated columns: shape (batch, seq + steps).
    """
    block_size = model.get_block_size()
    model.eval()

    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _, _ = model(x_cond, prop = prop, scaffold = scaffold) # for liggpt
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
|
| 55 |
+
|
| 56 |
+
def check_novelty(gen_smiles, train_smiles):
    """Return (and print) the percentage of generated SMILES absent from the training set.

    Args:
        gen_smiles: list of generated SMILES strings.
        train_smiles: iterable of training-set SMILES strings.

    Returns:
        float in [0, 100]; 0.0 when gen_smiles is empty.
    """
    if len(gen_smiles) == 0:
        novel_ratio = 0.
    else:
        # Set membership is O(1); the previous list scan was O(len(train))
        # per generated molecule.
        train_set = set(train_smiles)
        novel = sum(1 for mol in gen_smiles if mol not in train_set)
        novel_ratio = novel * 100. / len(gen_smiles)
    print("novelty: {:.3f}%".format(novel_ratio))
    return novel_ratio
|
| 65 |
+
|
| 66 |
+
def canonic_smiles(smiles_or_mol):
    """Canonicalize a SMILES string (or Mol); returns None when RDKit cannot parse it."""
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
|
| 71 |
+
|
| 72 |
+
#Experimental Class for Smiles Enumeration, Iterator and SmilesIterator adapted from Keras 1.2.2
|
| 73 |
+
|
| 74 |
+
class Iterator(object):
    """Abstract base class for data iterators.

    Cycles over `n` samples in batches forever, optionally reshuffling at
    the start of every epoch.  Subclasses must implement `next()` (see
    `__next__`).

    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0         # position within the current epoch
        self.total_batches_seen = 0  # grows forever; also perturbs the seed
        # lock so multiple threads can safely advance the index generator
        self.lock = threading.Lock()
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')

    def reset(self):
        # restart the epoch position
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        """Infinite generator yielding (index_array, current_index, current_batch_size)."""
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if seed is not None:
                # vary the permutation per batch while staying reproducible
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                # new epoch: build the (optionally shuffled) index order
                index_array = np.arange(n)
                if shuffle:
                    index_array = np.random.permutation(n)

            current_index = (self.batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                self.batch_index += 1
            else:
                # final (possibly short) batch of the epoch
                current_batch_size = n - current_index
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (index_array[current_index: current_index + current_batch_size],
                   current_index, current_batch_size)

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        # delegates to the subclass-provided next()
        return self.next(*args, **kwargs)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class SmilesIterator(Iterator):
    """Iterator yielding data from a SMILES array.

    # Arguments
        x: Numpy array of SMILES input data.
        y: Numpy array of targets data (or None).
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32
                 ):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)

        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        # NOTE(review): uses x.shape[0] (not len(x)), so the incoming x must
        # already be array-like with a .shape -- confirm callers pass ndarrays.
        super(SmilesIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """Return the next batch.

        Each batch is a float array of shape (batch, pad, charlen) built by
        the smiles_data_generator's transform(); targets are appended when
        y was provided.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The transformation itself is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros(tuple([current_batch_size] + [ self.smiles_data_generator.pad, self.smiles_data_generator._charlen]), dtype=self.dtype)
        for i, j in enumerate(index_array):
            smiles = self.x[j:j+1]
            x = self.smiles_data_generator.transform(smiles)
            batch_x[i] = x
        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer.

    #Arguments
        charset: string containing the characters for the vectorization;
            can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate (randomize) the SMILES during transform
        canonical: use canonical SMILES during transform (overrides enum)
    """

    def __init__(self, charset = '@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        self.charset = charset
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        # Rebuild both lookup tables whenever the charset is replaced.
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = {c: i for i, c in enumerate(charset)}
        self._int_to_char = {i: c for i, c in enumerate(charset)}

    def fit(self, smiles, extra_chars=[], extra_pad = 5):
        """Derive charset and pad length from a SMILES dataset.

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: extra characters to include in the charset
            extra_pad: extra padding added on top of the longest SMILES length
        """
        observed = set("".join(list(smiles)))
        self.charset = "".join(observed.union(set(extra_chars)))
        self.pad = max(len(s) for s in smiles) + extra_pad

    def randomize_smiles(self, smiles):
        """Return a randomized SMILES for the molecule (must be RDKit sanitizable)."""
        mol = Chem.MolFromSmiles(smiles)
        order = list(range(mol.GetNumAtoms()))
        np.random.shuffle(order)
        renumbered = Chem.RenumberAtoms(mol, order)
        return Chem.MolToSmiles(renumbered, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """One-hot encode an array of SMILES strings.

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
        """
        one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen), dtype=np.int8)
        for row, ss in enumerate(smiles):
            if self.enumerate:
                ss = self.randomize_smiles(ss)
            # Left padding shifts the characters to the right end of the window.
            offset = self.pad - len(ss) if self.leftpad else 0
            for col, ch in enumerate(ss):
                one_hot[row, col + offset, self._char_to_int[ch]] = 1
        return one_hot

    def reverse_transform(self, vect):
        """Decode vectorized SMILES back to strings.

        The charset must be the same as the one used for vectorization.

        #Arguments
            vect: Numpy array of vectorized SMILES.
        """
        decoded = []
        for v in vect:
            # Drop padding rows (all-zero one-hot rows sum to 0, real rows to 1).
            v = v[v.sum(axis=1) == 1]
            decoded.append("".join(self._int_to_char[i] for i in v.argmax(axis=1)))
        return np.array(decoded)
|
tool/converters.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.tools import BaseTool
|
| 2 |
+
|
| 3 |
+
from tool.chemspace import ChemSpace
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from utils import (
|
| 7 |
+
is_multiple_smiles,
|
| 8 |
+
is_smiles,
|
| 9 |
+
pubchem_query2smiles,
|
| 10 |
+
query2cas,
|
| 11 |
+
smiles2name,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Query2CAS(BaseTool):
    """LangChain tool: resolve a molecule (name or SMILES) to its CAS number via PubChem."""

    # Tool metadata read by the LangChain agent framework.
    name:str = "Mol2CAS"
    description:str = "Input molecule (name or SMILES), returns CAS number."
    url_cid: str = None   # PubChem endpoint template: query -> compound CIDs
    url_data: str = None  # PubChem endpoint template: CID -> full record data


    def __init__(
        self,
    ):
        super().__init__()
        self.url_cid = (
            "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/{}/{}/cids/JSON"
        )
        self.url_data = (
            "https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{}/JSON"
        )

    def _run(self, query: str) -> str:
        """Return the CAS number for *query*, or an error message string."""
        try:
            # if query is smiles
            smiles = None
            if is_smiles(query):
                smiles = query
            try:
                cas = query2cas(query, self.url_cid, self.url_data)
            except ValueError as e:
                return str(e)
            if smiles is None:
                # Round-trip the CAS back to SMILES purely as a sanity check that
                # PubChem actually has the compound; the value itself is unused.
                try:
                    smiles = pubchem_query2smiles(cas, None)
                except ValueError as e:
                    return str(e)

            return cas
        except ValueError:
            return "CAS number not found"

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Query2SMILES(BaseTool):
    """LangChain tool: resolve a molecule name / CAS number to a SMILES string.

    Primary source is PubChem; when a ChemSpace API key is configured it is
    used as a fallback on PubChem failure.
    """

    name:str = "CAS2SMILES"
    description :str = "Input a CAS number, returns SMILES."
    url: str = None                   # PubChem "by name" endpoint template
    chemspace_api_key: str = None     # optional fallback provider credential

    def __init__(self, chemspace_api_key: str = None):
        super().__init__()
        self.chemspace_api_key = chemspace_api_key
        self.url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}"

    def _run(self, query: str) -> str:
        """This function queries the given molecule name and returns a SMILES string from the record"""
        """Useful to get the SMILES string of one molecule by searching the name of a molecule. Only query with one specific name."""
        if is_smiles(query) and is_multiple_smiles(query):
            return "Multiple SMILES strings detected, input one molecule at a time."
        try:
            smi = pubchem_query2smiles(query, self.url)
        except Exception as e:
            if self.chemspace_api_key:
                # Fall back to ChemSpace; on any failure report the original
                # PubChem error, which is the more informative one.
                try:
                    chemspace = ChemSpace(self.chemspace_api_key)
                    smi = chemspace.convert_mol_rep(query, "smiles")
                    smi = smi.split(":")[1]
                except Exception:
                    return str(e)
            else:
                # BUG FIX: the original referenced an undefined `chemspace`
                # object in this branch (guaranteed NameError, silently
                # swallowed). Without an API key there is no fallback, so
                # surface the PubChem error directly.
                return str(e)

        return smi

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 98 |
+
|
| 99 |
+
class Mol2SMILES(BaseTool):
    """LangChain tool: resolve a molecule name to a SMILES string.

    Tries PubChem first, then falls back to the bundled local dataset CSV.
    """

    name:str = "Mol2SMILES"
    description :str = "Input a molecular name , returns SMILES."

    def __init__(self, chemspace_api_key: str = None):
        # chemspace_api_key is accepted for signature compatibility but unused.
        super().__init__()

    def _run(self, query: str) -> str:
        """This function queries the given molecule name and returns a SMILES string from the record"""
        """Useful to get the SMILES string of one molecule by searching the name of a molecule. Only query with one specific name."""
        if is_smiles(query) and is_multiple_smiles(query):
            return "Multiple SMILES strings detected, input one molecule at a time."
        try:
            smi = pubchem_query2smiles(query)
            return smi
        except Exception as e:
            # PubChem failed; fall back to the local name->SMILES dataset.
            try:
                csv_data = pd.read_csv('tool/dataset.csv', encoding='ISO-8859-1')
                relevant_rows = csv_data[csv_data['Name'] == query]
                if not relevant_rows.empty:
                    # Get the most relevant answer (return the first match)
                    return relevant_rows.iloc[0]['SMILES']
                # BUG FIX: the original fell through and implicitly returned
                # None when the name was absent from the CSV; report the
                # original PubChem error instead so the agent sees a string.
                return str(e)
            except Exception:
                # Narrowed from a bare `except:`; report the original error.
                return str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 127 |
+
|
| 128 |
+
class SMILES2Name(BaseTool):
    """LangChain tool: resolve a SMILES string (or name/CAS) to a molecule name."""

    name:str = "SMILES2Name"
    description:str = "Input SMILES, returns molecule name."



    def __init__(self):
        super().__init__()

    def _run(self, query: str) -> str:
        """Return the molecule name for *query*, or an "Error: ..." string."""
        try:
            if not is_smiles(query):
                # Non-SMILES input: try resolving it to SMILES via PubChem first.
                try:
                    query2smiles = Query2SMILES()
                    query = query2smiles.run(query)
                except Exception:
                    # BUG FIX: narrowed from a bare `except:` so system exits /
                    # keyboard interrupts are no longer swallowed here.
                    raise ValueError("Invalid molecule input, no Pubchem entry")
            name = smiles2name(query)

            return name
        except Exception as e:
            return "Error: " + str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
tool/csv_search.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Mon Dec 23 16:18:29 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from langchain.tools import BaseTool
|
| 9 |
+
import pandas as pd
|
| 10 |
+
|
| 11 |
+
class search_csv(BaseTool):
    """LangChain tool: look up a material name in the local dataset CSV and return its SMILES."""

    # BUG FIX: pydantic fields need type annotations (matching the sibling
    # tools in this package); the original bare assignments are not declared
    # as model fields by pydantic v2 BaseTool.
    name: str = "csvsearch"
    description: str = (
        "input name, return the SMILES of materials "
        "convert name to SMILES."
    )
    # BUG FIX: the original declared `llm: BaseLanguageModel = None`, but
    # BaseLanguageModel is never imported in this module, so the class body
    # raised NameError at import time. The field was unused and is removed.
    openai_api_key: str = None
    semantic_scholar_api_key: str = None

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        """Return the SMILES for the material named *smiles*, or None if absent.

        NOTE(review): the path 'dataset.csv' is cwd-relative while
        converters.py reads 'tool/dataset.csv' — confirm which is intended.
        """
        csv_data = pd.read_csv('dataset.csv', encoding='ISO-8859-1')
        # BUG FIX: the original filtered on an undefined name `query`
        # (NameError at runtime); the lookup key is the `smiles` parameter.
        relevant_rows = csv_data[csv_data['Name'] == smiles]

        if not relevant_rows.empty:
            # Get the most relevant answer (return the first match)
            return relevant_rows.iloc[0]['SMILES']
        return None

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 34 |
+
|
tool/dap/.gitignore
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 110 |
+
.pdm.toml
|
| 111 |
+
|
| 112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 113 |
+
__pypackages__/
|
| 114 |
+
|
| 115 |
+
# Celery stuff
|
| 116 |
+
celerybeat-schedule
|
| 117 |
+
celerybeat.pid
|
| 118 |
+
|
| 119 |
+
# SageMath parsed files
|
| 120 |
+
*.sage.py
|
| 121 |
+
|
| 122 |
+
# Environments
|
| 123 |
+
.env
|
| 124 |
+
.venv
|
| 125 |
+
env/
|
| 126 |
+
venv/
|
| 127 |
+
ENV/
|
| 128 |
+
env.bak/
|
| 129 |
+
venv.bak/
|
| 130 |
+
|
| 131 |
+
# Spyder project settings
|
| 132 |
+
.spyderproject
|
| 133 |
+
.spyproject
|
| 134 |
+
|
| 135 |
+
# Rope project settings
|
| 136 |
+
.ropeproject
|
| 137 |
+
|
| 138 |
+
# mkdocs documentation
|
| 139 |
+
/site
|
| 140 |
+
|
| 141 |
+
# mypy
|
| 142 |
+
.mypy_cache/
|
| 143 |
+
.dmypy.json
|
| 144 |
+
dmypy.json
|
| 145 |
+
|
| 146 |
+
# Pyre type checker
|
| 147 |
+
.pyre/
|
| 148 |
+
|
| 149 |
+
# pytype static type analyzer
|
| 150 |
+
.pytype/
|
| 151 |
+
|
| 152 |
+
# Cython debug symbols
|
| 153 |
+
cython_debug/
|
| 154 |
+
|
| 155 |
+
# PyCharm
|
| 156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
+
#.idea/
|
tool/dap/OSC/test.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1081233b2f0b3c77752a98b3c9e4ae065cb21aae4e3e5d31f8d673a1c2069ded
|
| 3 |
+
size 81596523
|
tool/dap/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# database
|
tool/dap/config/config_hparam.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{ "name": "biomarker_log",
|
| 2 |
+
|
| 3 |
+
"d_model_name" : "DeepChem/ChemBERTa-10M-MTR",
|
| 4 |
+
"p_model_name" : "DeepChem/ChemBERTa-77M-MLM",
|
| 5 |
+
"gpu_ids" : "0",
|
| 6 |
+
"model_mode" : "train",
|
| 7 |
+
"load_checkpoint" : "./checkpoint/bindingDB/test.ckpt",
|
| 8 |
+
|
| 9 |
+
"prot_maxlength" : 360,
|
| 10 |
+
"layer_limit" : true,
|
| 11 |
+
|
| 12 |
+
"max_epoch": 16,
|
| 13 |
+
"batch_size": 40,
|
| 14 |
+
"num_workers": 0,
|
| 15 |
+
|
| 16 |
+
"task_name" : "OSC",
|
| 17 |
+
"lr": 1e-4,
|
| 18 |
+
"layer_features" : [512, 128, 64, 1],
|
| 19 |
+
"dropout" : 0.1,
|
| 20 |
+
"loss_fn" : "MSE",
|
| 21 |
+
|
| 22 |
+
"traindata_rate" : 1.0,
|
| 23 |
+
"pretrained": {"chem":true, "prot":true},
|
| 24 |
+
"num_seed" : 111
|
| 25 |
+
}
|
| 26 |
+
|
tool/dap/config/predict.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{ "name": "biomarker_log",
|
| 2 |
+
|
| 3 |
+
"d_model_name" : "DeepChem/ChemBERTa-10M-MLM",
|
| 4 |
+
"p_model_name" : "DeepChem/ChemBERTa-10M-MTR",
|
| 5 |
+
"gpu_ids" : "0",
|
| 6 |
+
"model_mode" : "test",
|
| 7 |
+
"load_checkpoint" : "tool/dap/OSC/test.ckpt",
|
| 8 |
+
|
| 9 |
+
"prot_maxlength" : 360,
|
| 10 |
+
"layer_limit" : true,
|
| 11 |
+
|
| 12 |
+
"max_epoch": 16,
|
| 13 |
+
"batch_size": 40,
|
| 14 |
+
"num_workers": 0,
|
| 15 |
+
|
| 16 |
+
"task_name" : "OSC",
|
| 17 |
+
"lr": 1e-4,
|
| 18 |
+
"layer_features" : [128, 128, 128, 1],
|
| 19 |
+
"dropout" : 0.1,
|
| 20 |
+
"loss_fn" : "MSE",
|
| 21 |
+
|
| 22 |
+
"traindata_rate" : 1.0,
|
| 23 |
+
"pretrained": {"chem":true, "prot":true},
|
| 24 |
+
"num_seed" : 111
|
| 25 |
+
}
|
| 26 |
+
|
tool/dap/requirements.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair
|
| 2 |
+
streamlit
|
| 3 |
+
streamlit-ketcher
|
| 4 |
+
torch
|
| 5 |
+
tqdm
|
| 6 |
+
transformers
|
| 7 |
+
pytorch_lightning
|
| 8 |
+
scipy
|
| 9 |
+
pandas
|
| 10 |
+
rdkit
|
| 11 |
+
scikit-learn
|
| 12 |
+
matplotlib
|
| 13 |
+
easydict
|
| 14 |
+
wandb
|
| 15 |
+
networkx
|
| 16 |
+
seaborn
|
| 17 |
+
|
| 18 |
+
|
tool/dap/run.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from .util.utils import *
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from .train import markerModel
|
| 12 |
+
|
| 13 |
+
# Select GPUs by PCI bus id and restrict visibility to device 0.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0 '  # NOTE(review): trailing space looks accidental — confirm

device_count = torch.cuda.device_count()
device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Tokenized tensors are kept on CPU here; see marker_prediction below.
device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'  # checkpoint for `tokenizer`
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'  # checkpoint for `d_tokenizer`

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

#--biomarker Model
##-- hyper param config file Load --##
config = load_hparams('tool/dap/config/predict.json')
config = DictX(config)
# Instantiated with config hyperparameters, then immediately replaced by the
# checkpointed weights below (strict=False tolerates key mismatches).
model = markerModel(config.d_model_name, config.p_model_name,
                   config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])

model = markerModel.load_from_checkpoint(config.load_checkpoint,strict=False)
model.eval()
model.freeze()

# Wrap for multi-GPU inference when CUDA is available.
if device_biomarker.type == 'cuda':
    model = torch.nn.DataParallel(model)
|
| 39 |
+
|
| 40 |
+
def get_marker(drug_inputs, prot_inputs):
    """Run the loaded model on two tokenized input dicts and return predictions.

    Returns the squeezed model output as a Python list (a scalar for a
    single-item batch, per torch.squeeze().tolist()).
    """
    output_preds = model(drug_inputs, prot_inputs)

    predict = torch.squeeze( (output_preds)).tolist()

    # output_preds = torch.relu(output_preds)
    # predict = torch.tanh(output_preds)
    # predict = predict.squeeze(dim=1).tolist()

    return predict
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def marker_prediction(smiles, aas):
    """Tokenize two lists of molecule strings and run the model on them.

    smiles: list of SMILES strings tokenized with `tokenizer`.
    aas: list of strings tokenized character-by-character (space-joined)
         with `d_tokenizer`; the name suggests amino-acid sequences, but
         callers in this file pass SMILES — TODO confirm intent.

    Returns the prediction list from get_marker, or an {'Error_message': e}
    dict (with the exception object, not its string) on failure.
    """
    try:
        # Space-separate every character so the tokenizer sees them individually.
        aas_input = []
        for ass_data in aas:
            aas_input.append(' '.join(list(ass_data)))

        a_inputs = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # d_inputs = tokenizer(smiles, truncation=True, return_tensors="pt")
        a_input_ids = a_inputs['input_ids'].to(device)
        a_attention_mask = a_inputs['attention_mask'].to(device)
        a_inputs = {'input_ids': a_input_ids, 'attention_mask': a_attention_mask}

        d_inputs = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # p_inputs = prot_tokenizer(aas_input, truncation=True, return_tensors="pt")
        d_input_ids = d_inputs['input_ids'].to(device)
        d_attention_mask = d_inputs['attention_mask'].to(device)
        d_inputs = {'input_ids': d_input_ids, 'attention_mask': d_attention_mask}

        output_list = get_marker(a_inputs, d_inputs)


        return output_list

    except Exception as e:
        print(e)
        return {'Error_message': e}
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def smiles_aas_test(smile_acc,smile_don):
    """Predict for a single acceptor/donor SMILES pair.

    Both inputs are canonicalized via an RDKit round-trip before prediction.
    Returns whatever marker_prediction yields for the one-pair batch.
    """
    # Canonicalize both SMILES so equivalent notations map to one form.
    mola = Chem.MolFromSmiles(smile_acc)
    smile_acc = Chem.MolToSmiles(mola, canonical=True)
    mold = Chem.MolFromSmiles(smile_don)
    smile_don = Chem.MolToSmiles(mold, canonical=True)

    batch_size = 1  # one pair per batch; the batching below is a no-op at this size

    datas = []
    marker_list = []
    marker_datas = []



    marker_datas.append([smile_acc,smile_don])
    if len(marker_datas) == batch_size:
        marker_list.append(list(marker_datas))
        marker_datas.clear()

    # Flush any partial batch (never triggers with batch_size == 1).
    if len(marker_datas) != 0:
        marker_list.append(list(marker_datas))
        marker_datas.clear()

    for marker_datas in tqdm(marker_list, total=len(marker_list)):
        # NOTE(review): pairs were appended as [acc, don], so `smiles_d`
        # actually unpacks the acceptors and `smiles_a` the donors — the
        # names look swapped; confirm against marker_prediction's contract.
        smiles_d , smiles_a = zip(*marker_datas)
        output_pred = marker_prediction(list(smiles_d), list(smiles_a) )
        if len(datas) == 0:
            datas = output_pred
        else:
            datas = datas + output_pred

    # ## -- Export result data to csv -- ##
    # df = pd.DataFrame(datas)
    # df.to_csv('./results/predictData_nontonon_bindingdb_test.csv', index=None)

    # print(df)
    return datas
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == '__main__':
    # Smoke test: predict for one hard-coded acceptor/donor SMILES pair.
    a = smiles_aas_test('CC(C)CCCC(C)CCC1=C(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)SC2=C1N(CCC(C)CCCC(C)C)C1=C2C2=NSN=C2C2=C1N(CCC(C)CCCC(C)C)C1=C2SC(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)=C1CCC(C)CCCC(C)C','CCCCC(CC)CC1=C(F)C=C(C2=C3C=C(C4=CC=C(C5=C6C(=O)C7=C(CC(CC)CCCC)SC(CC(CC)CCCC)=C7C(=O)C6=C(C6=CC=C(C)S6)S5)S4)SC3=C(C3=CC(F)=C(CC(CC)CCCC)S3)C3=C2SC(C)=C3)S1')
|
| 124 |
+
|
tool/dap/screen.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from rdkit import Chem
|
| 8 |
+
from .util.utils import *
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from .train import markerModel
|
| 12 |
+
|
| 13 |
+
# Select GPUs by PCI bus id and restrict visibility to device 0.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0 '  # NOTE(review): trailing space looks accidental — confirm

device_count = torch.cuda.device_count()
device_biberta= torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Tokenized tensors are kept on CPU here; see biberta_prediction below.
device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'  # checkpoint for `tokenizer`
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'  # checkpoint for `d_tokenizer`

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

#--bibertaModel
##-- hyper param config file Load --##

config = load_hparams('tool/dap/config/predict.json')
config = DictX(config)
# Instantiated with config hyperparameters, then immediately replaced by the
# checkpointed weights below (strict=False tolerates key mismatches).
model = markerModel(config.d_model_name, config.p_model_name,
                   config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])

model = markerModel.load_from_checkpoint(config.load_checkpoint,strict=False)
model.eval()
model.freeze()

# Wrap for multi-GPU inference when CUDA is available.
if device_biberta.type == 'cuda':
    model = torch.nn.DataParallel(model)
|
| 40 |
+
|
| 41 |
+
def get_biberta(drug_inputs, prot_inputs):
    """Run the loaded model on two tokenized input dicts and return predictions.

    Returns the squeezed model output as a Python list (a scalar for a
    single-item batch, per torch.squeeze().tolist()).
    """
    output_preds = model(drug_inputs, prot_inputs)

    predict = torch.squeeze( (output_preds)).tolist()

    # output_preds = torch.relu(output_preds)
    # predict = torch.tanh(output_preds)
    # predict = predict.squeeze(dim=1).tolist()

    return predict
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def biberta_prediction(smiles, aas):
    """Tokenize two equal-length lists of molecule strings and predict each pair.

    smiles: list of acceptor SMILES tokenized with `tokenizer`.
    aas: list of donor strings tokenized character-by-character (space-joined)
         with `d_tokenizer` — named for amino acids but callers pass SMILES;
         TODO confirm intent.

    Returns a list of {'acceptor', 'donor', 'predict'} dicts, or an
    {'Error_message': e} dict (exception object, not string) on failure.
    """
    try:
        # Space-separate every character so the tokenizer sees them individually.
        aas_input = []
        for ass_data in aas:
            aas_input.append(' '.join(list(ass_data)))

        a_inputs = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # d_inputs = tokenizer(smiles, truncation=True, return_tensors="pt")
        a_input_ids = a_inputs['input_ids'].to(device)
        a_attention_mask = a_inputs['attention_mask'].to(device)
        a_inputs = {'input_ids': a_input_ids, 'attention_mask': a_attention_mask}

        d_inputs = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        # p_inputs = prot_tokenizer(aas_input, truncation=True, return_tensors="pt")
        d_input_ids = d_inputs['input_ids'].to(device)
        d_attention_mask = d_inputs['attention_mask'].to(device)
        d_inputs = {'input_ids': d_input_ids, 'attention_mask': d_attention_mask}

        output_predict = get_biberta(a_inputs, d_inputs)

        # Pair each prediction back with its input molecules.
        output_list = [{'acceptor': smiles[i], 'donor': aas[i], 'predict': output_predict[i]} for i in range(0,len(aas))]

        return output_list

    except Exception as e:
        print(e)
        return {'Error_message': e}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def smiles_aas_test(file):
    """Batch screening: read a CSV with 'donor' and 'acceptor' SMILES columns
    and predict every pair.

    file: path (or buffer) accepted by pandas.read_csv.

    Returns the list of {'acceptor', 'donor', 'predict'} dicts produced by
    biberta_prediction, or an {'Error_message': e} dict on failure.
    """
    try:
        smiles_aas = pd.read_csv(file)

        smiles_d , smiles_a = (smiles_aas['donor'],smiles_aas['acceptor'])

        # Canonicalize every SMILES through an RDKit round-trip so
        # equivalent notations map to one form.
        donor,acceptor =[],[]
        for i in smiles_d:
            s = Chem.MolToSmiles(Chem.MolFromSmiles(i))
            donor.append(s)
        for i in smiles_a:
            s = Chem.MolToSmiles(Chem.MolFromSmiles(i))
            acceptor.append(s)

        # Single prediction call over the whole file. The original carried
        # dead batching state (batch_se, biberta_list, biberta_datas), an
        # always-true `if len(datas) == 0` branch, and an unused DataFrame;
        # all removed — behavior is unchanged.
        datas = biberta_prediction(list(acceptor), list(donor) )

        return datas

    except Exception as e:
        print(e)
        return {'Error_message': e}
|
| 118 |
+
|
tool/dap/train.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
| 3 |
+
|
| 4 |
+
import gc, os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from scipy.stats import pearsonr
|
| 9 |
+
from .util.utils import *
|
| 10 |
+
#from .util.attention_flow import *
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
import sklearn as sk
|
| 16 |
+
from torch.utils.data import Dataset, DataLoader
|
| 17 |
+
|
| 18 |
+
import pytorch_lightning as pl
|
| 19 |
+
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
|
| 20 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
| 21 |
+
from transformers import AutoConfig, AutoTokenizer, RobertaModel, BertModel
|
| 22 |
+
from sklearn.metrics import r2_score, mean_absolute_error,mean_squared_error
|
| 23 |
+
|
| 24 |
+
class markerDataset(Dataset):
    """Torch dataset of (acceptor SMILES, donor SMILES, label) triples.

    Each sample is tokenized on the fly: the acceptor string with
    ``d_tokenizer`` and the donor string with ``p_tokenizer``, padded and
    truncated to a fixed length, and returned together with a float label.
    """

    def __init__(self, list_IDs, labels, df_dti, d_tokenizer, p_tokenizer):
        'Initialization'
        self.labels = labels        # per-row regression targets, indexed like df_dti
        self.list_IDs = list_IDs    # row indices into df_dti to serve
        self.df = df_dti            # dataframe with 'acceptor' / 'donor' columns

        self.d_tokenizer = d_tokenizer  # acceptor-side tokenizer
        self.p_tokenizer = p_tokenizer  # donor-side tokenizer

    def convert_data(self, acc_data, don_data):
        """Tokenize one acceptor/donor pair (no padding).

        Returns two dicts, each with 'input_ids' and 'attention_mask'
        batch tensors.
        """
        d_inputs = self.d_tokenizer(acc_data, return_tensors="pt")
        # Bug fix: the donor string was previously tokenized with the
        # acceptor tokenizer (self.d_tokenizer); every other method in
        # this class uses p_tokenizer for donors.
        p_inputs = self.p_tokenizer(don_data, return_tensors="pt")

        acc_inputs = {'input_ids': d_inputs['input_ids'],
                      'attention_mask': d_inputs['attention_mask']}
        don_inputs = {'input_ids': p_inputs['input_ids'],
                      'attention_mask': p_inputs['attention_mask']}

        return acc_inputs, don_inputs

    def tokenize_data(self, acc_data, don_data):
        """Return raw token lists wrapped in [CLS]/[SEP] markers."""
        tokenize_acc = ['[CLS]'] + self.d_tokenizer.tokenize(acc_data) + ['[SEP]']
        tokenize_don = ['[CLS]'] + self.p_tokenizer.tokenize(don_data) + ['[SEP]']
        return tokenize_acc, tokenize_don

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.list_IDs)

    def __getitem__(self, index):
        'Generates one sample of data'
        index = self.list_IDs[index]
        acc_data = self.df.iloc[index]['acceptor']
        don_data = self.df.iloc[index]['donor']

        d_inputs = self.d_tokenizer(acc_data, padding='max_length', max_length=400, truncation=True, return_tensors="pt")
        p_inputs = self.p_tokenizer(don_data, padding='max_length', max_length=400, truncation=True, return_tensors="pt")

        # squeeze() drops the batch axis added by return_tensors="pt";
        # the DataLoader re-batches the samples itself.
        d_input_ids = d_inputs['input_ids'].squeeze()
        d_attention_mask = d_inputs['attention_mask'].squeeze()
        p_input_ids = p_inputs['input_ids'].squeeze()
        p_attention_mask = p_inputs['attention_mask'].squeeze()

        labels = torch.as_tensor(self.labels[index], dtype=torch.float)

        return [d_input_ids, d_attention_mask, p_input_ids, p_attention_mask, labels]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class markerDataModule(pl.LightningDataModule):
    """Lightning data module serving tokenized acceptor/donor SMILES pairs.

    Loads train/val(/test) CSVs, optionally subsamples them by
    ``traindata_rate``, and wraps them in :class:`markerDataset`.
    """

    def __init__(self, task_name, acc_model_name, don_model_name, num_workers, batch_size, traindata_rate = 1.0):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.task_name = task_name

        # Fraction of the train/val CSVs actually used (1.0 = everything).
        self.traindata_rate = traindata_rate

        self.d_tokenizer = AutoTokenizer.from_pretrained(acc_model_name)
        self.p_tokenizer = AutoTokenizer.from_pretrained(don_model_name)

        self.df_train = None
        self.df_val = None
        self.df_test = None

        self.load_testData = True

        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None

    def get_task(self, task_name):
        """Map a task name (case-insensitive) to its dataset folder.

        Side effect: the 'merge' task has no test split, so it disables
        test-set loading. Returns None for unknown task names.
        """
        # Bug fix: the lowercased name was compared against the
        # upper-case literal 'OSC', so the branch could never match.
        if task_name.lower() == 'osc':
            return './dataset/OSC/'

        elif task_name.lower() == 'merge':
            self.load_testData = False
            return './dataset/MergeDataset'

    def prepare_data(self):
        # Use this method to do things that might write to disk or that need to be done only from
        # a single process in distributed settings.
        # NOTE(review): the folder is hard-coded and get_task() is never
        # consulted here -- confirm whether that is intentional.
        dataFolder = './dataset/OSC'

        self.df_train = pd.read_csv(dataFolder + '/train.csv')
        self.df_val = pd.read_csv(dataFolder + '/val.csv')

        ## -- apply the data-length rate -- ##
        traindata_length = int(len(self.df_train) * self.traindata_rate)
        validdata_length = int(len(self.df_val) * self.traindata_rate)

        self.df_train = self.df_train[:traindata_length]
        self.df_val = self.df_val[:validdata_length]

        if self.load_testData is True:
            self.df_test = pd.read_csv(dataFolder + '/test.csv')

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = markerDataset(self.df_train.index.values, self.df_train.Label.values, self.df_train,
                                               self.d_tokenizer, self.p_tokenizer)
            self.valid_dataset = markerDataset(self.df_val.index.values, self.df_val.Label.values, self.df_val,
                                               self.d_tokenizer, self.p_tokenizer)

        if self.load_testData is True:
            self.test_dataset = markerDataset(self.df_test.index.values, self.df_test.Label.values, self.df_test,
                                              self.d_tokenizer, self.p_tokenizer)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class markerModel(pl.LightningModule):
    """Twin-RoBERTa regression model.

    Encodes acceptor and donor SMILES with two RoBERTa encoders,
    concatenates the two [CLS] embeddings, and regresses a scalar through
    a small MLP decoder.

    NOTE(review): the *_step_end / *_epoch_end hooks assume
    pytorch-lightning < 2.0 -- confirm the pinned version.
    """

    def __init__(self, acc_model_name, don_model_name, lr, dropout, layer_features, loss_fn = "smooth", layer_limit = True, d_pretrained=True, p_pretrained=True):
        super().__init__()
        self.lr = lr
        self.loss_fn = loss_fn  # 'MSE' selects MSELoss, anything else SmoothL1
        self.criterion = torch.nn.MSELoss()
        self.criterion_smooth = torch.nn.SmoothL1Loss()
        # self.sigmoid = nn.Sigmoid()

        #-- Pretrained Model Setting
        acc_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")
        if d_pretrained is False:
            self.d_model = RobertaModel(acc_config)
            print('acceptor model without pretraining')
        else:
            self.d_model = RobertaModel.from_pretrained(acc_model_name, num_labels=2,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        don_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")

        if p_pretrained is False:
            self.p_model = RobertaModel(don_config)
            print('donor model without pretraining')
        else:
            self.p_model = RobertaModel.from_pretrained(don_model_name,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        #-- Decoder Layer Setting: Linear -> ReLU -> (Dropout) per hidden layer
        layers = []
        firstfeature = self.d_model.config.hidden_size + self.p_model.config.hidden_size
        for feature_idx in range(0, len(layer_features) - 1):
            layers.append(nn.Linear(firstfeature, layer_features[feature_idx]))
            firstfeature = layer_features[feature_idx]

            # Bug fix: the original branched on
            # `feature_idx is len(layer_features)-2` (identity comparison
            # on ints), but both branches appended nn.ReLU() anyway; a
            # single unconditional append is equivalent and correct.
            layers.append(nn.ReLU())

            if dropout > 0:
                layers.append(nn.Dropout(dropout))

        layers.append(nn.Linear(firstfeature, layer_features[-1]))

        self.decoder = nn.Sequential(*layers)

        self.save_hyperparameters()

    def forward(self, acc_inputs, don_inputs):
        """Return raw regression outputs of shape (batch, 1)."""
        d_outputs = self.d_model(acc_inputs['input_ids'], acc_inputs['attention_mask'])
        p_outputs = self.p_model(don_inputs['input_ids'], don_inputs['attention_mask'])

        # Concatenate the two [CLS] embeddings and decode.
        outs = torch.cat((d_outputs.last_hidden_state[:, 0], p_outputs.last_hidden_state[:, 0]), dim=1)
        outs = self.decoder(outs)

        return outs

    def attention_output(self, acc_inputs, don_inputs):
        """Like forward(), but also returns both encoders' attentions."""
        d_outputs = self.d_model(acc_inputs['input_ids'], acc_inputs['attention_mask'])
        p_outputs = self.p_model(don_inputs['input_ids'], don_inputs['attention_mask'])

        outs = torch.cat((d_outputs.last_hidden_state[:, 0], p_outputs.last_hidden_state[:, 0]), dim=1)
        outs = self.decoder(outs)

        return d_outputs['attentions'], p_outputs['attentions'], outs

    def _compute_loss(self, logits, labels):
        """Select the configured regression loss."""
        if self.loss_fn == 'MSE':
            return self.criterion(logits, labels)
        return self.criterion_smooth(logits, labels)

    def training_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        output = self(acc_inputs, don_inputs)
        logits = output.squeeze(dim=1)

        loss = self._compute_loss(logits, labels)

        self.log("train_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        output = self(acc_inputs, don_inputs)
        logits = output.squeeze(dim=1)

        loss = self._compute_loss(logits, labels)

        self.log("valid_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def validation_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def validation_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # Bug fix: labels were cast to int before scoring, truncating the
        # float regression targets and distorting mae/mse/r2.
        labels = torch.cat([output['labels'] for output in outputs], dim=0).float()

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        output = self(acc_inputs, don_inputs)
        logits = output.squeeze(dim=1)

        loss = self._compute_loss(logits, labels)

        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def test_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def test_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # Same fix as validation_epoch_end: keep float labels.
        labels = torch.cat([output['labels'] for output in outputs], dim=0).float()

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)
        self.log("r", r, on_step=False, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """AdamW with weight decay disabled for bias/LayerNorm parameters."""
        param_optimizer = list(self.named_parameters())

        no_decay = ["bias", "gamma", "beta"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay_rate": 0.0001
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay_rate": 0.0
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
        )
        return optimizer

    def convert_outputs_to_preds(self, outputs):
        """Concatenate the per-batch logits into one prediction tensor."""
        logits = torch.cat([output['logits'] for output in outputs], dim=0)
        return logits

    def log_score(self, preds, labels):
        """Compute mae / mse / r2 / pearson r on CPU numpy copies."""
        y_pred = preds.detach().cpu().numpy()
        y_label = labels.detach().cpu().numpy()

        mae = mean_absolute_error(y_label, y_pred)
        mse = mean_squared_error(y_label, y_pred)
        r2 = r2_score(y_label, y_pred)
        # Bug fix: pearsonr returns (statistic, p-value); keep only the
        # correlation so self.log("r", ...) receives a scalar.
        r = pearsonr(y_label, y_pred)[0]
        print(f'\nmae : {mae}')
        print(f'mse : {mse}')
        print(f'r2 : {r2}')
        print(f'r : {r}')

        return mae, mse, r2, r
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def main_wandb(config=None):
    """Train/evaluate with Weights & Biases logging.

    NOTE(review): `wandb` is never imported in this module -- as written,
    every call raises NameError, which the broad except swallows; confirm
    the intended wandb dependency before relying on this path.
    """
    try:
        if config is not None:
            wandb.init(config=config, project=project_name)
        else:
            wandb.init(settings=wandb.Settings(console='off'))

        config = wandb.config
        pl.seed_everything(seed=config.num_seed)

        # Bug fix: an extra `config.prot_maxlength` argument was passed
        # here; markerDataModule.__init__ takes (task_name, acc_model,
        # don_model, num_workers, batch_size, traindata_rate) only, so
        # the call raised TypeError before training could start.
        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)
        dm.prepare_data()
        dm.setup()

        model_type = str(config.pretrained['chem']) + "To" + str(config.pretrained['prot'])
        #model_logger = WandbLogger(project=project_name)
        # Bug fix: mae is an error metric, so keep the checkpoint with
        # the *lowest* value (mode="min"; it was "max").
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}",
                                              save_top_k=1, monitor="mae", mode="min")

        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision=16,
            #logger=model_logger,
            callbacks=[checkpoint_callback],
            accelerator='cpu', log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn,
                                config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])
            model.train()
            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)

    except Exception as e:
        print(e)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def main_default(config):
    """Train/evaluate without wandb, using the JSON hyper-param dict."""
    try:
        config = DictX(config)
        pl.seed_everything(seed=config.num_seed)

        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)

        dm.prepare_data()
        dm.setup()
        model_type = str(config.pretrained['chem']) + "To" + str(config.pretrained['prot'])
        # model_logger = TensorBoardLogger("./log", name=f"{config.task_name}_{model_type}_{config.num_seed}")
        # Bug fix: mse is an error metric, so keep the checkpoint with
        # the *lowest* value (mode="min"; it was "max").
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}",
                                              save_top_k=1, monitor="mse", mode="min")

        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision=32,
            # logger=model_logger,
            callbacks=[checkpoint_callback],
            accelerator='cpu', log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn,
                                config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])

            model.train()

            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)
    except Exception as e:
        print(e)
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
if __name__ == '__main__':
    # Toggle between the wandb-logged run and the plain default run.
    using_wandb = False

    if using_wandb:
        ##-- load the hyper-parameter config and train with wandb --##
        config = load_hparams('config/config_hparam.json')
        project_name = config["name"]

        main_wandb(config)

        ##-- wandb Sweep Hyper Param Tuning --##
        # config = load_hparams('config/config_sweep_bindingDB.json')
        # project_name = config["name"]
        # sweep_id = wandb.sweep(config, project=project_name)
        # wandb.agent(sweep_id, main_wandb)

    else:
        config = load_hparams('config/config_hparam.json')

        main_default(config)
|
tool/dap/util/attention_flow.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import networkx as nx
|
| 2 |
+
import numpy as np
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
import itertools
|
| 9 |
+
import matplotlib as mpl
|
| 10 |
+
# import cugraph as cnx
|
| 11 |
+
|
| 12 |
+
rc={'font.size': 10, 'axes.labelsize': 10, 'legend.fontsize': 10.0,
|
| 13 |
+
'axes.titlesize': 32, 'xtick.labelsize': 20, 'ytick.labelsize': 16}
|
| 14 |
+
plt.rcParams.update(**rc)
|
| 15 |
+
mpl.rcParams['axes.linewidth'] = .5 #set the value globally
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def plot_attention_heatmap(att, s_position, t_positions, input_tokens):
    """Heatmap of attention from token `s_position` to `t_positions`,
    one row per layer (flipped so layer 1 sits at the bottom)."""
    layer_rows = np.flip(att[:, s_position, t_positions], axis=0)

    # X labels: only the tokens at the targeted positions.
    x_labels = [tok for idx, tok in enumerate(input_tokens) if idx in t_positions]
    # Y labels: layer numbers, every other one blanked for readability.
    y_labels = ['' if layer % 2 else str(layer) for layer in np.arange(att.shape[0], 0, -1)]

    return sns.heatmap(layer_rows, xticklabels=x_labels, yticklabels=y_labels, cmap="YlOrRd")
|
| 26 |
+
|
| 27 |
+
def convert_adjmat_tomats(adjmat, n_layers, l):
    """Slice the stacked (n_layers+1)*l adjacency matrix back into one
    l x l attention matrix per layer."""
    mats = np.zeros((n_layers, l, l))

    for layer in range(n_layers):
        row_lo, row_hi = (layer + 1) * l, (layer + 2) * l
        col_lo, col_hi = layer * l, (layer + 1) * l
        mats[layer] = adjmat[row_lo:row_hi, col_lo:col_hi]

    return mats
|
| 34 |
+
|
| 35 |
+
def make_residual_attention(attentions):
    """Average attention heads and fold in the residual connection.

    Returns (attentions_mat, res_att_mat): the raw per-layer attentions
    for batch element 0, and the head-averaged, identity-augmented,
    row-normalized version.
    """
    stacked = np.asarray([layer.detach().cpu().numpy() for layer in attentions])
    attentions_mat = stacked[:, 0]  # keep batch element 0 only

    # Mean over heads, add the identity for the residual stream, then
    # renormalize every row back to a probability distribution.
    res_att_mat = attentions_mat.mean(axis=1)
    res_att_mat = res_att_mat + np.eye(res_att_mat.shape[1])[None, ...]
    res_att_mat = res_att_mat / res_att_mat.sum(axis=-1)[..., None]

    return attentions_mat, res_att_mat
|
| 44 |
+
|
| 45 |
+
## -------------------------------------------------------- ##
|
| 46 |
+
## -- Make flow network (No Print Node - edge Connection)-- ##
|
| 47 |
+
## -------------------------------------------------------- ##
|
| 48 |
+
|
| 49 |
+
def make_flow_network(mat, input_tokens):
    """Build a directed capacity graph from per-layer attention matrices.

    Node 0..length-1 are the input tokens; layer i contributes `length`
    more nodes with edges back to layer i-1 whose capacity is the
    attention weight. Returns (graph, label->index map).
    """
    n_layers, length, _ = mat.shape
    n_nodes = (n_layers + 1) * length
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Layer-0 node labels carry the input tokens.
    labels_to_index = {str(k) + "_" + input_tokens[k]: k for k in range(length)}

    # Edges run from layer i down to layer i-1, capacity = attention.
    for layer in range(1, n_layers + 1):
        for src in range(length):
            src_index = layer * length + src
            labels_to_index[f"L{layer}_{src}"] = src_index
            for dst in range(length):
                adj_mat[src_index][(layer - 1) * length + dst] = mat[layer - 1][src][dst]

    # NOTE(review): nx.from_numpy_matrix was removed in networkx 3.0
    # (from_numpy_array is the replacement) -- confirm the pinned version.
    net_graph = nx.from_numpy_matrix(adj_mat, create_using=nx.DiGraph())
    for i in np.arange(adj_mat.shape[0]):
        for j in np.arange(adj_mat.shape[1]):
            nx.set_edge_attributes(net_graph, {(i, j): adj_mat[i, j]}, 'capacity')

    return net_graph, labels_to_index
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def make_input_node(attention_mat, res_labels_to_index):
    """Return labels of the input-layer nodes: every label whose index is
    below one sequence length (the last axis of attention_mat)."""
    seq_len = attention_mat.shape[-1]
    return [label for label, idx in res_labels_to_index.items() if idx < seq_len]
|
| 80 |
+
## ------------------------------------------------ ##
|
| 81 |
+
## -- Draw Attention flow node - Edge Connection -- ##
|
| 82 |
+
## ------------------------------------------------ ##
|
| 83 |
+
|
| 84 |
+
##-- networkx graph Initation and Calculation flow --##
|
| 85 |
+
def get_adjmat(mat, input_tokens):
    """Stack per-layer attention matrices into one big adjacency matrix.

    Returns (adj_mat, labels_to_index) where adj_mat has one block row
    per layer: adj_mat[layer*L + src][(layer-1)*L + dst] is the attention
    from `src` in `layer` to `dst` in the layer below.
    """
    n_layers, length, _ = mat.shape
    n_nodes = (n_layers + 1) * length
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Layer-0 node labels carry the input tokens.
    labels_to_index = {str(k) + "_" + input_tokens[k]: k for k in range(length)}

    for layer in range(1, n_layers + 1):
        for src in range(length):
            src_index = layer * length + src
            labels_to_index[f"L{layer}_{src}"] = src_index
            for dst in range(length):
                adj_mat[src_index][(layer - 1) * length + dst] = mat[layer - 1][src][dst]

    return adj_mat, labels_to_index
|
| 102 |
+
|
| 103 |
+
def draw_attention_graph(adjmat, labels_to_index, n_layers, length):
    """Render the layered attention graph on the current matplotlib axes:
    one column of nodes per layer, edge width equal to attention weight.
    Returns the networkx graph."""
    A = adjmat
    # NOTE(review): nx.from_numpy_matrix was removed in networkx 3.0 --
    # confirm the pinned networkx version.
    net_graph = nx.from_numpy_matrix(A, create_using=nx.DiGraph())
    for i in np.arange(A.shape[0]):
        for j in np.arange(A.shape[1]):
            nx.set_edge_attributes(net_graph, {(i, j): A[i, j]}, 'capacity')

    # Node positions: one column per layer, tokens ordered top-to-bottom.
    pos = {}
    label_pos = {}
    for layer in np.arange(n_layers + 1):
        for tok in np.arange(length):
            node = layer * length + tok
            pos[node] = ((layer + 0.4) * 2, length - tok)
            label_pos[node] = (layer * 2, length - tok)

    # Only the layer-0 nodes keep a visible (token) label.
    index_to_labels = {}
    for key, idx in labels_to_index.items():
        index_to_labels[idx] = '' if idx >= length else key.split("_")[-1]

    #plt.figure(1,figsize=(20,12))
    nx.draw_networkx_nodes(net_graph, pos, node_color='green', labels=index_to_labels, node_size=50)
    nx.draw_networkx_labels(net_graph, pos=label_pos, labels=index_to_labels, font_size=18)

    # Collect every edge weight, then draw edges grouped by weight so each
    # group shares one line width (= the raw weight; normalization was
    # left commented out in the original).
    all_weights = [data['weight'] for _, _, data in net_graph.edges(data=True)]
    for weight in set(all_weights):
        weighted_edges = [(u, v) for (u, v, attr) in net_graph.edges(data=True) if attr['weight'] == weight]
        nx.draw_networkx_edges(net_graph, pos, edgelist=weighted_edges, width=weight, edge_color='darkblue')

    return net_graph
|
| 146 |
+
|
| 147 |
+
def compute_flows(G, labels_to_index, input_nodes, length):
    """Attention-flow: max-flow from every non-input node back to each
    input token, row-normalized. Returns a (nodes x nodes) matrix."""
    n_nodes = len(labels_to_index)
    flow_values = np.zeros((n_nodes, n_nodes))

    for label in tqdm(labels_to_index, desc="flow algorithms", total=len(labels_to_index)):
        if label in input_nodes:
            continue
        src = labels_to_index[label]
        prev_offset = (int(src / length) - 1) * length
        for input_label in input_nodes:
            sink = labels_to_index[input_label]
            # Edmonds-Karp max-flow from this node down to the input token.
            flow_values[src][prev_offset + sink] = nx.maximum_flow_value(
                G, src, sink, flow_func=nx.algorithms.flow.edmonds_karp)
        # Normalize each row to a distribution over input tokens.
        flow_values[src] /= flow_values[src].sum()

    return flow_values
|
| 163 |
+
|
| 164 |
+
def compute_node_flow(G, labels_to_index, input_nodes, output_nodes, length):
    """Like compute_flows, but restricted to the given output nodes."""
    n_nodes = len(labels_to_index)
    flow_values = np.zeros((n_nodes, n_nodes))

    for label in output_nodes:
        if label in input_nodes:
            continue
        src = labels_to_index[label]
        prev_offset = (int(src / length) - 1) * length
        for input_label in input_nodes:
            sink = labels_to_index[input_label]
            flow_values[src][prev_offset + sink] = nx.maximum_flow_value(
                G, src, sink, flow_func=nx.algorithms.flow.edmonds_karp)
        # Normalize each row to a distribution over input tokens.
        flow_values[src] /= flow_values[src].sum()

    return flow_values
|
| 179 |
+
|
| 180 |
+
def compute_joint_attention(att_mat, add_residual=True):
    """Attention rollout: accumulate per-layer attention into cumulative
    token-to-token attention by matrix-multiplying up the stack.

    With add_residual, the identity is blended in first (modeling skip
    connections) and each row renormalized to a distribution.
    """
    if add_residual:
        aug_att_mat = att_mat + np.eye(att_mat.shape[1])[None, ...]
        aug_att_mat = aug_att_mat / aug_att_mat.sum(axis=-1)[..., None]
    else:
        aug_att_mat = att_mat

    joint_attentions = np.zeros(aug_att_mat.shape)
    joint_attentions[0] = aug_att_mat[0]
    for layer in range(1, joint_attentions.shape[0]):
        joint_attentions[layer] = aug_att_mat[layer].dot(joint_attentions[layer - 1])

    return joint_attentions
|
tool/dap/util/attention_plot.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
import plotly.express as px
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
|
| 6 |
+
def make_attention_table(att, tokens, numb, token_idx = 0, layerNumb = -1):
    """Export and plot the attention one token pays to the rest of a sequence.

    Writes the raw values to amino_acid_seq_attention_{numb}.csv and saves a
    bar chart (the three highest-attention tokens highlighted in crimson) to
    figures/Amino_acid_seq_{numb}.png, then shows the figure.

    Parameters
    ----------
    att : array-like indexed [layer, query_token, key_token]
        Attention weights. Assumed shape — TODO confirm against the caller.
    tokens : sequence of str
        Token labels; tokens[0] (presumably a [CLS]-style token) is skipped.
    numb : int or str
        Suffix used in the output file names.
    token_idx : int
        Query token whose attention row is plotted.
    layerNumb : int
        Attention layer to plot (-1 = last layer).
    """
    # Attention of token_idx toward every token except the first.
    token_att = att[layerNumb, token_idx, range(1, len(tokens))]

    token_label = []
    token_numb = []
    for idx, token in enumerate(tokens[1:]):
        token_label.append(f"<b>{token}</b>")  # bold tick label
        token_numb.append(f"{idx}")            # positional x value (string)

    pair = list(zip(token_numb, token_att))

    df = pd.DataFrame(pair, columns=["Amino acid", "Attention rate"])
    df.to_csv(f"amino_acid_seq_attention_{numb}.csv", index=None)

    # Indices of the three largest attention values, to highlight.
    top3_idx = sorted(range(len(token_att)), key=lambda i: token_att[i], reverse=True)[:3]

    colors = ['cornflowerblue', ] * len(token_numb)

    for i in top3_idx:
        colors[i] = 'crimson'

    fig = go.Figure(data=[go.Bar(
        x=df["Amino acid"],
        y=df["Attention rate"],
        # range_y=[min(token_att), max(token_att)],
        marker_color=colors  # marker color can be a single color value or an iterable
    )])

    # fig = px.histogram(df, x="Amino acid", y="Attention rate", range_y=[min(token_att), max(token_att)])

    # White background with faint axis lines, large centered bold title.
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
                             'font': {'size': 40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      xaxis=dict(tickmode='array',
                                 tickvals=token_numb,
                                 ticktext=token_label
                                 ),

                      xaxis_title={'text': "Amino acid sequence",
                                   'font': {'size': 30}},
                      yaxis_title={'text': "Attention rate",
                                   'font': {'size': 30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    fig.write_image(f'figures/Amino_acid_seq_{numb}.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def read_attention():
    """Re-plot attention values previously exported to ../amino_acid_seq_attention.csv.

    Expects columns "Amino acid" and "Attention rate"; saves
    figures/Amino_acid_seq.png and shows the bar chart.
    """
    df = pd.read_csv("../amino_acid_seq_attention.csv")
    # d_flow_values = np.asarray(d_read_flow_values)

    fig = px.bar(df, x="Amino acid", y="Attention rate", range_y=[min(df["Attention rate"]), max(df["Attention rate"])])

    # Same styling as make_attention_table: white background, faint axes.
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
                             'font': {'size': 40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      xaxis_title={'text': "Amino acid sequence",
                                   'font': {'size': 30}},
                      yaxis_title={'text': "Attention rate",
                                   'font': {'size': 30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    fig.write_image('figures/Amino_acid_seq.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
|
| 91 |
+
|
| 92 |
+
if __name__ == '__main__':
    # Re-plot the attention values previously exported to CSV.
    read_attention()
|
tool/dap/util/boxplot.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from scipy import stats
|
| 5 |
+
import plotly.express as px
|
| 6 |
+
|
| 7 |
+
from plotly.subplots import make_subplots
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
|
| 10 |
+
ROC = 1
|
| 11 |
+
PR = 2
|
| 12 |
+
|
| 13 |
+
def add_p_value_annotation(fig, array_columns, subplot=None, _format=dict(interline=0.03, text_height=1.03, color='black')):
    ''' Adds notations giving the p-value between two box plot data (t-test two-sided comparison)

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: np.array
        array of which columns to compare
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
    subplot: None or int
        specifies if the figures has subplots and what subplot to add the notation to
    _format: dict
        format characteristics for the lines

    Returns:
    -------
    fig: figure
        figure with the added notation
    '''
    # NOTE(review): the mutable dict default for _format is shared across
    # calls; safe only while it is never mutated inside this function.
    # Specify in what y_range to plot for each pair of columns
    y_range = np.zeros([len(array_columns), 2])
    for i in range(len(array_columns)):
        y_range[i] = [1.03+i*_format['interline'], 1.04+i*_format['interline']]

    # Get values from figure
    fig_dict = fig.to_dict()

    # Get indices if working with subplots
    if subplot:
        if subplot == 1:
            # The first subplot's axes are named 'x'/'y' with no suffix.
            subplot_str = ''
        else:
            subplot_str = str(subplot)
        indices = []  # Change the box index to the indices of the data for that subplot
        for index, data in enumerate(fig_dict['data']):
            #print(index, data['xaxis'], 'x' + subplot_str)
            if data['xaxis'] == 'x' + subplot_str:
                indices = np.append(indices, index)
        indices = [int(i) for i in indices]
        print((indices))
    else:
        subplot_str = ''

    # Print the p-values
    for index, column_pair in enumerate(array_columns):
        if subplot:
            data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
        else:
            data_pair = column_pair

        # Mare sure it is selecting the data and subplot you want
        #print('0:', fig_dict['data'][data_pair[0]]['name'], fig_dict['data'][data_pair[0]]['xaxis'])
        #print('1:', fig_dict['data'][data_pair[1]]['name'], fig_dict['data'][data_pair[1]]['xaxis'])

        # Welch's two-sided t-test between the two traces' y data
        # (equal_var=False); [1] extracts the p-value.
        pvalue = stats.ttest_ind(
            fig_dict['data'][data_pair[0]]['y'],
            fig_dict['data'][data_pair[1]]['y'],
            equal_var=False,
        )[1]
        # Map the p-value to the conventional significance symbol.
        if pvalue >= 0.05:
            symbol = 'ns'
        elif pvalue >= 0.01:
            symbol = '*'
        elif pvalue >= 0.001:
            symbol = '**'
        else:
            symbol = '***'
        # Vertical line
        fig.add_shape(type="line",
                      xref="x"+subplot_str, yref="y"+subplot_str+" domain",
                      x0=column_pair[0], y0=y_range[index][0],
                      x1=column_pair[0], y1=y_range[index][1],
                      line=dict(color=_format['color'], width=1.5,)
                      )
        # Horizontal line
        fig.add_shape(type="line",
                      xref="x"+subplot_str, yref="y"+subplot_str+" domain",
                      x0=column_pair[0], y0=y_range[index][1],
                      x1=column_pair[1], y1=y_range[index][1],
                      line=dict(color=_format['color'], width=1.5,)
                      )
        # Vertical line
        fig.add_shape(type="line",
                      xref="x"+subplot_str, yref="y"+subplot_str+" domain",
                      x0=column_pair[1], y0=y_range[index][0],
                      x1=column_pair[1], y1=y_range[index][1],
                      line=dict(color=_format['color'], width=1.5,)
                      )
        ## add text at the correct x, y coordinates
        ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
        fig.add_annotation(dict(font=dict(color=_format['color'], size=14),
                                x=(column_pair[0] + column_pair[1])/2,
                                y=y_range[index][1]*_format['text_height'],
                                showarrow=False,
                                text=symbol,
                                textangle=0,
                                xref="x"+subplot_str,
                                yref="y"+subplot_str+" domain"
                                ))
    return fig
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def box_plot(df):
    """Draw and save a grouped box plot of ROC-AUC scores per dataset/model.

    Expects columns 'Task_name', 'test_auroc' and 'Model' in *df*.
    Saves ../figures/box_plot_integration.png and shows the figure.
    """
    fig = px.box(df, x='Task_name', y='test_auroc', color="Model")

    # White background with faint axis lines / grid.
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)', mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)', mirror=False)
    fig.update_layout(title={'text': "<b>ROC-AUC score distribution</b>",
                             'font': {'size': 40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      xaxis_title={'text': "Datasets",
                                   'font': {'size': 30}},
                      yaxis_title={'text': "ROC-AUC",
                                   'font': {'size': 30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    # Significance annotations comparing trace columns 0/3/6 against
    # column 7 — presumably the "ours" model; verify against model order.
    fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)

    fig.write_image('../figures/box_plot_integration.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def go_box_plot(df, metric = ROC):
    """Draw per-dataset box plots of model scores as a 1x3 subplot figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain columns 'Task_name', 'Model' and the selected metric
        column ('test_auroc' or 'test_auprc').
    metric : int
        ROC (1) or PR (2); selects the metric column and the output name.

    Saves ../figures/boxplot_auroc.png (or boxplot_auprc.png) and shows
    the figure.
    """
    dataset_list = ['BIOSNAP', 'DAVIS', 'BindingDB']
    model_list = ['LR', 'DNN', 'GNN-CPI', 'DeepDTI', 'DeepDTA', 'DeepConv-DTI', 'Moltrans', 'ours']
    clr_list = ['red', 'orange', 'green', 'indianred', 'lightseagreen', 'goldenrod', 'magenta', 'blue']

    if metric == ROC:
        # fig_title = "<b>ROC-AUC score distribution</b>"
        file_title = "boxplot_auroc.png"
        select_metric = "test_auroc"
    else:
        # fig_title = "<b>PR-AUC score distribution</b>"
        file_title = "boxplot_auprc.png"
        select_metric = "test_auprc"

    fig = make_subplots(rows=1, cols=3, subplot_titles=[c for c in dataset_list])

    groups = df.groupby(df.Task_name)
    # sic "Legand": show the legend for the first subplot only.
    Legand = True

    for dataset_idx, dataset in enumerate(dataset_list):
        df_modelgroup = groups.get_group(dataset)
        model_groups = df_modelgroup.groupby(df_modelgroup.Model)
        if dataset_idx != 0:
            Legand = False
        for model_idx, model in enumerate(model_list):
            df_data = model_groups.get_group(model)
            # NOTE(review): append_trace is a deprecated alias of add_trace
            # in recent plotly versions — consider migrating.
            fig.append_trace(go.Box(y=df_data[select_metric],
                                    name=model,
                                    marker_color=clr_list[model_idx],
                                    showlegend = Legand
                                    ),
                             row=1,
                             col=dataset_idx+1)




    # fig.update_layout(title={'text': fig_title,
    #                          'font':{'size':25},
    #                          'y': 0.98,
    #                          'x': 0.46,
    #                          'xanchor': 'center',
    #                          'yanchor': 'top'})

    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)
    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=2)
    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=3)

    fig.write_image(f'../figures/{file_title}', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == '__main__':
    # Plot the scores previously exported from Weights & Biases.
    df = pd.read_csv("../dataset/wandb_export_boxplotdata.csv")
    box_plot(df)
|
tool/dap/util/data/bindingdb_kd.tab
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b72a38ae07a75d5d4c269d2776b6e62e0edde29ff7cf8a323158c08951f808d1
|
| 3 |
+
size 54432102
|
tool/dap/util/data/davis.tab
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d4c6809dcb7c5da2b91a32d594d6935b75484940bde4d18055eb5e1059262f4
|
| 3 |
+
size 21376712
|
tool/dap/util/emetric.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def get_cindex(Y, P):
    """Concordance index between true values Y and predictions P.

    Over all ordered pairs with Y[i] > Y[j], counts how often the
    predictions agree in ranking (prediction ties count 0.5).

    Parameters
    ----------
    Y, P : sequences of comparable numbers of equal length.

    Returns
    -------
    float in [0, 1], or 0 when no comparable pair exists.
    """
    summ = 0
    pair = 0

    for i in range(1, len(Y)):
        for j in range(0, i):
            # Fix: the original guarded with `if i is not j` and
            # `if pair is not 0` — identity comparison of ints relies on
            # CPython small-int caching and raises a SyntaxWarning on
            # modern Python. `i != j` is always true here (j < i), so the
            # inner guard is dropped; the final guard uses `!=`.
            if Y[i] > Y[j]:
                pair += 1
                summ += 1 * (P[i] > P[j]) + 0.5 * (P[i] == P[j])

    if pair != 0:
        return summ / pair
    else:
        return 0
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def r_squared_error(y_obs, y_pred):
    """Return the squared Pearson correlation between *y_obs* and *y_pred*."""
    obs = np.array(y_obs)
    pred = np.array(y_pred)
    # Deviations from the respective means.
    obs_dev = obs - np.mean(obs)
    pred_dev = pred - np.mean(pred)

    # Squared cross term of the deviations.
    covariance_sq = np.sum(pred_dev * obs_dev) ** 2

    # Product of the two sums of squares.
    denom = np.sum(obs_dev * obs_dev) * np.sum(pred_dev * pred_dev)

    return covariance_sq / float(denom)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_k(y_obs, y_pred):
    """Slope k of the zero-intercept least-squares fit y_obs ~ k * y_pred."""
    obs = np.array(y_obs)
    pred = np.array(y_pred)
    # <obs, pred> / <pred, pred> — the closed-form through-origin slope.
    return np.dot(obs, pred) / float(np.dot(pred, pred))
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def squared_error_zero(y_obs, y_pred):
    """Coefficient of determination for the zero-intercept fit y_obs ~ k*y_pred."""
    k = get_k(y_obs, y_pred)

    obs = np.array(y_obs)
    pred = np.array(y_pred)
    # Residual sum of squares of the through-origin fit.
    residual = obs - k * pred
    upp = np.sum(residual * residual)
    # Total sum of squares about the observed mean.
    centered = obs - np.mean(obs)
    down = np.sum(centered * centered)

    return 1 - (upp / float(down))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_rm2(ys_orig, ys_line):
    """Modified r_m^2 metric: r2 * (1 - sqrt(|r2^2 - r0^2|))."""
    r2 = r_squared_error(ys_orig, ys_line)
    r02 = squared_error_zero(ys_orig, ys_line)

    # Penalize the gap between the ordinary and the zero-intercept fits.
    gap = np.absolute((r2 * r2) - (r02 * r02))
    return r2 * (1 - np.sqrt(gap))
|
tool/dap/util/load_dataset.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tdc.multi_pred import DTI
|
| 2 |
+
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
if __name__ == '__main__':
    # Download the two DTI benchmarks via the Therapeutics Data Commons.
    bindingDB_data = DTI(name = 'BindingDB_Kd')
    davis_data = DTI(name = 'DAVIS')

    # NOTE(review): the return value of harmonize_affinities is discarded
    # here — confirm the TDC API mutates in place rather than returning the
    # harmonized data.
    bindingDB_data.harmonize_affinities(mode = 'max_affinity')

    bindingDB_data.convert_to_log(form = 'binding')
    davis_data.convert_to_log(form = 'binding')

    split_bindingDB = bindingDB_data.get_split()
    split_davis = davis_data.get_split()

    # Write each split of each benchmark to ../dataset_kd/ as CSV.
    dataset_list = ["train", "valid", "test"]
    for dataset_type in dataset_list:
        df_bindingDB = pd.DataFrame(split_bindingDB[dataset_type])
        df_davis = pd.DataFrame(split_davis[dataset_type])

        df_bindingDB.to_csv(f"../dataset_kd/bindingDB_{dataset_type}.csv", index=False)
        df_davis.to_csv(f"../dataset_kd/davis_{dataset_type}.csv", index=False)


    # After the loop these refer to the *test* split only.
    Y_bindingDB = np.array(df_bindingDB.Y)
    Y_davis = np.array(df_davis.Y)

    # NOTE(review): unused; labels were already log-converted above.
    Y_davis_log = [np.log10(Y_davis)]
|
tool/dap/util/make_external_validation.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
if __name__ == '__main__':
    # Build an external validation set as the cross product of every
    # external SMILES with every amino-acid sequence.
    smiles = pd.read_csv("../dataset/external_smiles.csv")
    ass = pd.read_csv("../dataset/external_aas.csv")

    smiles_data = list(np.array(smiles['smiles']))
    smiles_label = list(np.array(smiles['label'].tolist()))
    # NOTE(review): smiles_label is split but never used afterwards.
    smiles_label = [x.split() for x in smiles_label]

    ass_data = list(np.array(ass['aas']))
    cyp_type = list(np.array(ass['CYP_type']))

    external_dataset = []
    for smiles_idx in range(0, len(smiles_data)):
        for ass_idx in range(0, len(ass_data)):
            # One row per (compound, protein) pair, tagged with the CYP type.
            external_data = [smiles_data[smiles_idx], ass_data[ass_idx], cyp_type[ass_idx]]
            external_dataset.append(external_data)

    df = pd.DataFrame(external_dataset, columns=['smiles', 'aas', 'CYP_type'])
    df.to_csv('../dataset/external_dataset.csv', index=False)


    # Sanity prints of the first entries.
    print(smiles['smiles'][0])
    print(ass['CYP_type'][0])
|
tool/dap/util/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, copy
|
| 2 |
+
from easydict import EasyDict
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
class DictX(dict):
    """A dict whose keys are also reachable as attributes (d.key <-> d['key']).

    Missing keys raise AttributeError on attribute access/deletion, so the
    object still behaves sanely with hasattr()/getattr().
    """

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as missing:
            # Translate to AttributeError so attribute protocols work.
            raise AttributeError(missing)

    def __setattr__(self, key, value):
        # Attribute assignment writes straight into the mapping.
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as missing:
            raise AttributeError(missing)

    def __repr__(self):
        return f'<DictX {dict.__repr__(self)}>'
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_hparams(file_path):
    """Load a JSON hyper-parameter file.

    Parameters
    ----------
    file_path : str
        Path to a JSON file of hyper-parameters.

    Returns
    -------
    dict parsed from the file.
    """
    # The original pre-initialized `hparams = EasyDict()` and then
    # immediately rebound the name to json.load's plain dict — the EasyDict
    # instance was never used, so the dead initialization is removed.
    with open(file_path, 'r') as f:
        hparams = json.load(f)
    return hparams
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of *model* with only the first encoder layers kept.

    Parameters
    ----------
    model : module exposing an ``encoder.layer`` nn.ModuleList (e.g. BERT).
    num_layers_to_keep : int
        Number of leading encoder layers to retain.

    Returns
    -------
    A deep copy of *model* whose ``encoder.layer`` holds references to the
    first *num_layers_to_keep* layers of the ORIGINAL model (weights are
    shared with the input model, matching the original implementation).
    """
    # Collect references to the leading layers of the original model.
    kept_layers = nn.ModuleList(
        model.encoder.layer[i] for i in range(num_layers_to_keep)
    )

    # Copy the whole model, then swap in the truncated layer list.
    trimmed_model = copy.deepcopy(model)
    trimmed_model.encoder.layer = kept_layers

    return trimmed_model
|
tool/dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tool/deepacceptor/RF.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Mon Sep 4 10:38:59 2023
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from sklearn.metrics import confusion_matrix
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
from sklearn.datasets import make_blobs
|
| 14 |
+
import json
|
| 15 |
+
import numpy as np
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
from scipy import sparse
|
| 19 |
+
from sklearn.metrics import median_absolute_error,r2_score, mean_absolute_error,mean_squared_error
|
| 20 |
+
import pickle
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
import pandas as pd
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
from rdkit import Chem
|
| 26 |
+
|
| 27 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def split_string(string):
    """Return the characters of *string* as a list, e.g. "ab" -> ["a", "b"]."""
    # list() over the manual append loop: identical result, idiomatic
    # and faster.
    return list(string)
|
| 38 |
+
def main(sm):
    """Predict a property for one SMILES using the pickled deepacceptor RF model.

    Parameters
    ----------
    sm : str
        A SMILES string.

    Returns
    -------
    Array with a single prediction from the pickled RandomForest model.
        NOTE(review): if the SMILES fails to parse, `features` stays empty
        and predict() will raise — confirm intended behavior for bad input.
    """

    inchis = list([sm])
    rts = list([0])  # dummy target; the model is used for inference only

    smiles, targets, features = [], [], []
    for i, inc in enumerate((inchis)):
        mol = Chem.MolFromSmiles(inc)
        if mol is None:
            continue
        else:
            # Morgan fingerprint, radius 3, 2048 bits, serialized to a
            # '0'/'1' bit string and split into single characters
            # (presumably coerced to numbers downstream — verify).
            smi = AllChem.GetMorganFingerprintAsBitVect(mol, 3, 2048)
            smi = smi.ToBitString()
            a = split_string(smi)
            a = np.array(a)
            #smi = Chem.MolToSmiles(mol)
            features.append(a)
            targets.append(rts[i])



    features = np.asarray(features)
    targets = np.asarray(targets)
    X_test = features
    Y_test = targets  # NOTE(review): unused
    n_features = 10   # NOTE(review): unused

    # NOTE(review): this freshly constructed model is never fitted or used;
    # only the pickled model below produces the prediction.
    model = RandomForestRegressor(n_estimators=500)

    # NOTE(review): Windows-style backslash path — breaks on POSIX; verify
    # the deployment platform (orbital.py uses forward slashes).
    load_model = pickle.load(open(r"tool\deepacceptor\deepacceptor.pkl", "rb"))
    Y_predict = load_model.predict(X_test)

    return Y_predict
|
tool/deepacceptor/deepacceptor.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11753f5d925de9fdf0ff0afc05204bcb54ad26753f13e02e63696fa3c65e8029
|
| 3 |
+
size 28084161
|
tool/deepacceptor/dict.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[" ", "C", "1", "=", "(", "2", "F", ")", "3", "4", "5", "#", "N", "S", "/", "\\", "O", "6", "7", "8", "9", "%", "0", "[", "Se", "]", "Cl", "Br", "B", ".", "P", "I", "@", "H"]
|
tool/deepdonor/pm.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:155cb847cef95d069044c425c409ca8daff368bd2f3310f43965b6c65a2914e2
|
| 3 |
+
size 8220594
|
tool/deepdonor/sm.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b27366195bfd2fcb74dbe20ccc1243e6297ee1f5c272947613f875141fdceb2
|
| 3 |
+
size 31982999
|
tool/graphconverter.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Thu Nov 7 15:38:35 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from DECIMER import predict_SMILES
|
| 9 |
+
from langchain.tools import BaseTool
|
| 10 |
+
|
| 11 |
+
class graphconverter(BaseTool):
    """LangChain tool: convert an image of a molecule to SMILES via DECIMER."""

    name: str = "graphconverter"
    description: str = (
        "Input graph path , returns SMILES."
        "It was used to convert graph/figure/image containing molecule to SMILES"
    )

    def __init__(self):
        super().__init__()

    def _run(self, paths: str) -> str:
        """Run DECIMER OCSR on the image at *paths*.

        Returns the predicted SMILES, or an error hint when prediction
        fails (e.g. the path does not exist or is not a readable image).
        """
        try:
            SMILES = predict_SMILES(paths)
        except Exception:
            # Fix: was a bare `except:`, which also swallowed SystemExit
            # and KeyboardInterrupt; narrow it so those still propagate.
            return 'Please recheck the graph path'
        return SMILES

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
tool/orbital.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Wed Oct 30 09:14:55 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from sklearn.metrics import confusion_matrix
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
from sklearn.datasets import make_blobs
|
| 14 |
+
import json
|
| 15 |
+
import numpy as np
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
from scipy import sparse
|
| 19 |
+
from sklearn.metrics import median_absolute_error,r2_score, mean_absolute_error,mean_squared_error
|
| 20 |
+
from langchain.tools import BaseTool
|
| 21 |
+
import pandas as pd
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
from rdkit import Chem
|
| 24 |
+
import pickle
|
| 25 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def split_string(string):
    """Return the characters of *string* as a list, e.g. "ab" -> ["a", "b"]."""
    # list() over the manual append loop: identical result, idiomatic
    # and faster.
    return list(string)
|
| 37 |
+
|
| 38 |
+
def main(sm):
    """Predict HOMO and LUMO energies for one SMILES with pickled RF models.

    Parameters
    ----------
    sm : str
        A SMILES string (assumed parseable; the caller checks validity).

    Returns
    -------
    (homo, lumo) : tuple of float, in eV per the calling tool's message.
    """


    inchis = list([sm])
    rts = list([0])  # dummy target; inference only

    smiles, targets, features = [], [], []
    for i, inc in enumerate(inchis):
        mol = Chem.MolFromSmiles(inc)
        if mol is None:
            continue
        else:
            # NOTE(review): the second positional argument of
            # GetMorganFingerprintAsBitVect is the *radius*, so this requests
            # radius=1024 with the default 2048 bits. RF.py uses
            # (mol, 3, 2048) — confirm what the pickled models expect.
            smi = AllChem.GetMorganFingerprintAsBitVect(mol, 1024)
            smi = smi.ToBitString()
            a = split_string(smi)
            a = np.array(a)
            #smi = Chem.MolToSmiles(mol)
            features.append(a)
            targets.append(rts[i])



    features = np.asarray(features)
    targets = np.asarray(targets)
    X_test = features
    Y_test = targets  # NOTE(review): unused
    n_features = 10   # NOTE(review): unused

    # NOTE(review): this freshly constructed model is never fitted or used;
    # only the two pickled models below produce predictions.
    model = RandomForestRegressor(n_estimators=100)
    load_homo = pickle.load(open(r"tool/orbital/homo.pkl", 'rb'))
    load_lumo = pickle.load(open(r"tool/orbital/lumo.pkl", 'rb'))
    # model = load_model('C:/Users/sunjinyu/Desktop/FingerID Reference/drug-likeness/CNN/single_model.h5')
    Y_homo = load_homo.predict(X_test)
    Y_lumo = load_lumo.predict(X_test)
    # Single-sample predictions: unwrap the 1-element arrays to floats.
    homo = float(Y_homo)
    lumo = float(Y_lumo)
    return homo, lumo
|
| 75 |
+
|
| 76 |
+
class homolumo_predictor(BaseTool):
    """LangChain tool that reports predicted HOMO/LUMO energies for a SMILES."""

    name: str = "homolumo_predictor"
    description: str = (
        "Input SMILES , returns the HOMO/LUMO (Highest Occupied Molecular Orbital (HOMO) \
        and Lowest Unoccupied Molecular Orbital)."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        """Validate *smiles* with RDKit, then delegate to main() and format."""
        # Reject strings RDKit cannot parse before touching the models.
        if Chem.MolFromSmiles(smiles) is None:
            return "Invalid SMILES string"
        homo_val, lumo_val = main(str(smiles))
        return (
            f"The HOMO is predicted to be {homo_val:.2f} eV , "
            f"the LUMO is predicted to be {lumo_val:.2f} eV"
        )

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 94 |
+
|
tool/pdfreader.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Mon Dec 30 22:20:13 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
from langchain.chains import LLMChain, SimpleSequentialChain, RetrievalQA, ConversationalRetrievalChain
|
| 8 |
+
|
| 9 |
+
from langchain import PromptTemplate
|
| 10 |
+
|
| 11 |
+
from langchain.tools import BaseTool
|
| 12 |
+
|
| 13 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 14 |
+
from langchain.base_language import BaseLanguageModel
|
| 15 |
+
from langchain.text_splitter import CharacterTextSplitter
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 19 |
+
from langchain_community.vectorstores import FAISS
|
| 20 |
+
from langchain_openai import ChatOpenAI
|
| 21 |
+
from langchain_openai import OpenAIEmbeddings
|
| 22 |
+
|
| 23 |
+
template = """
|
| 24 |
+
|
| 25 |
+
You are an expert chemist and your task is to respond to the question or
|
| 26 |
+
solve the problem to the best of your ability. You need to answer in as much detail as possible.
|
| 27 |
+
You can only respond with a single "Final Answer" format.
|
| 28 |
+
Use the following pieces of context to answer the question at the end.
|
| 29 |
+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
| 30 |
+
<context>
|
| 31 |
+
{context}
|
| 32 |
+
</context>
|
| 33 |
+
|
| 34 |
+
Question: {question}
|
| 35 |
+
Answer:
|
| 36 |
+
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
class pdfreader(BaseTool):
    """Tool: retrieval-augmented question answering over a local PDF file.

    Loads the PDF at ``self.path``, splits it into chunks, indexes the chunks
    in an in-memory FAISS store, and answers the query from the top matches.
    """

    name: str = "pdfreader"
    description: str = (
        "Used to read papers, summarize papers, Q&A based on papers, literature or publication"
        "Input query , return the response"
    )

    llm: BaseLanguageModel = None  # chat model used to generate the answer
    path: str = None               # filesystem path of the PDF to read
    return_direct: bool = True

    def __init__(self, path: str = None):
        import os  # local import: keeps the module-level import block untouched

        super().__init__()
        # SECURITY FIX: the API key was previously hard-coded in source control.
        # Credentials must come from the environment (OPENAI_API_KEY, and
        # optionally OPENAI_BASE_URL to override the gateway endpoint).
        self.llm = ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=os.environ.get("OPENAI_BASE_URL", "https://www.dmxapi.com/v1"),
        )
        self.path = path

    def _run(self, query) -> str:
        """Build a FAISS index over the PDF and answer `query` from it.

        Returns the chain's answer text. Raises whatever PyPDFLoader raises
        when ``self.path`` is missing or unreadable.
        """
        import os

        loader = PyPDFLoader(self.path)
        documents = loader.load()

        # Large chunks with generous overlap so answers can span page breaks.
        text_splitter = CharacterTextSplitter(chunk_size=6000, chunk_overlap=1000)
        docs = text_splitter.split_documents(documents)
        embeddings = OpenAIEmbeddings(
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=os.environ.get("OPENAI_BASE_URL", "https://www.dmxapi.com/v1"),
        )

        vectorstore = FAISS.from_documents(docs, embeddings)
        # FIX: declare "context" too — the "stuff" chain fills both variables.
        prompt = PromptTemplate(template=template, input_variables=["context", "question"])
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

        result = qa_chain.invoke(query)
        return result['result']

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
| 85 |
+
|
| 86 |
+
|
tool/property.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Thu Sep 5 21:42:51 2024
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from langchain.tools import BaseTool
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
from rdkit.Chem import rdMolDescriptors
|
| 11 |
+
from rdkit.Chem import Descriptors
|
| 12 |
+
from utils import *
|
| 13 |
+
from rdkit.Chem import RDConfig
|
| 14 |
+
from rdkit.ML.Descriptors import MoleculeDescriptors
|
| 15 |
+
|
| 16 |
+
from rdkit.Contrib.SA_Score import sascorer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class MolSimilarity(BaseTool):
    """Compute the Tanimoto similarity between two molecules given as SMILES."""

    name: str = "MolSimilarity"
    description: str = (
        "Input two molecule SMILES (separated by '.'), returns Tanimoto similarity."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles_pair: str) -> str:
        parts = smiles_pair.split(".")
        if len(parts) != 2:
            return "Input error, please input two smiles strings separated by '.'"

        smiles1, smiles2 = parts
        similarity = tanimoto(smiles1, smiles2)

        # tanimoto() reports parse failures by returning an error string.
        if isinstance(similarity, str):
            return similarity
        if similarity == 1:
            return "Error: Input Molecules Are Identical"
        return f"The Tanimoto similarity between {smiles1} and {smiles2} is {round(similarity, 4)}"

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class SMILES2Weight(BaseTool):
    """Return the exact (monoisotopic) molecular weight for a SMILES string."""

    name: str = "SMILES2Weight"
    description: str = "Input SMILES, returns molecular weight."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return "Invalid SMILES string"
        return rdMolDescriptors.CalcExactMolWt(molecule)

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 71 |
+
|
| 72 |
+
class SMILES2LogP(BaseTool):
    """Return the Wildman-Crippen LogP estimate for a SMILES string."""

    name: str = "SMILES2LogP"
    description: str = "Input SMILES, returns molecular LogP."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return "Invalid SMILES string"
        return Descriptors.MolLogP(molecule)

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 91 |
+
|
| 92 |
+
class SMILES2SAScore(BaseTool):
    """Report the synthetic-accessibility score of a molecule given as SMILES."""

    name: str = "SMILES2SAScore"
    description: str = "Input SMILES, returns synthetic accessibility score to evaluate the difficulty of molecular synthesis."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return "Invalid SMILES string"
        score = sascorer.calculateScore(molecule)
        return f"This SAScore of the molecule is {score}."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 111 |
+
|
| 112 |
+
class SMILES2Properties(BaseTool):
    """Report a fixed panel of physicochemical descriptors for a SMILES string."""

    name: str = "SMILES2Properties"
    description: str = "Input SMILES, returns basic physical and chemical properties."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return "Invalid SMILES string"

        # Synthetic accessibility plus an RDKit descriptor panel, computed in
        # one pass via MolecularDescriptorCalculator.
        SAScore = sascorer.calculateScore(molecule)
        descriptor_names = ['MolWt', 'NOCount', 'NumHAcceptors', 'NumHDonors', 'MolLogP',
                            'NumRotatableBonds', 'RingCount', 'NumAromaticRings', 'TPSA']
        calculator = MoleculeDescriptors.MolecularDescriptorCalculator(descriptor_names)
        results = calculator.CalcDescriptors(molecule)

        return f"SAScore: {'{:.2f}'.format(SAScore)}; molecular weight: {'{:.2f}'.format(results[0])}; number of Nitrogens and Oxygens: {results[1]}; number of Hydrogen Bond Acceptors: {results[2]}; number of Hydrogen Bond Donors:{results[3]}; LogP:{'{:.2f}'.format(results[4])}; number of Rotatable Bonds: {results[5]}; Ring count: {results[6]}; number of aromatic rings: {results[7]}; TPSA: {'{:.2f}'.format(results[8])}."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
| 136 |
+
|
| 137 |
+
class FuncGroups(BaseTool):
    """Identify named functional groups in a molecule via SMARTS matching."""

    name: str = "FunctionalGroups"
    description: str = "Input SMILES, return list of functional groups in the molecule."
    dict_fgs: dict = None  # maps human-readable group name -> SMARTS pattern

    def __init__(self):
        super().__init__()

        # List obtained from https://github.com/rdkit/rdkit/blob/master/Data/FunctionalGroups.txt
        # BUG FIX: the original literal defined the key "ketones" twice, so the
        # specific ketone SMARTS "[#6][CX3](=O)[#6]" was silently overwritten by
        # the generic carbonyl pattern "*=[O;D1]". The generic pattern now has
        # its own key ("carbonyls") so both patterns are actually matched.
        self.dict_fgs = {
            "furan": "o1cccc1",
            "aldehydes": " [CX3H1](=O)[#6]",
            "esters": " [#6][CX3](=O)[OX2H0][#6]",
            "ketones": " [#6][CX3](=O)[#6]",
            "amides": " C(=O)-N",
            "thiol groups": " [SH]",
            "alcohol groups": " [OH]",
            "methylamide": "*-[N;D2]-[C;D3](=O)-[C;D1;H3]",
            "carboxylic acids": "*-C(=O)[O;D1]",
            "carbonyl methylester": "*-C(=O)[O;D2]-[C;D1;H3]",
            "terminal aldehyde": "*-C(=O)-[C;D1]",
            "amide": "*-C(=O)-[N;D1]",
            "carbonyl methyl": "*-C(=O)-[C;D1;H3]",
            "isocyanate": "*-[N;D2]=[C;D2]=[O;D1]",
            "isothiocyanate": "*-[N;D2]=[C;D2]=[S;D1]",
            "nitro": "*-[N;D3](=[O;D1])[O;D1]",
            "nitroso": "*-[N;R0]=[O;D1]",
            "oximes": "*=[N;R0]-[O;D1]",
            "Imines": "*-[N;R0]=[C;D1;H2]",
            "terminal azo": "*-[N;D2]=[N;D2]-[C;D1;H3]",
            "hydrazines": "*-[N;D2]=[N;D1]",
            "diazo": "*-[N;D2]#[N;D1]",
            "cyano": "*-[C;D2]#[N;D1]",
            "primary sulfonamide": "*-[S;D4](=[O;D1])(=[O;D1])-[N;D1]",
            "methyl sulfonamide": "*-[N;D2]-[S;D4](=[O;D1])(=[O;D1])-[C;D1;H3]",
            "sulfonic acid": "*-[S;D4](=O)(=O)-[O;D1]",
            "methyl ester sulfonyl": "*-[S;D4](=O)(=O)-[O;D2]-[C;D1;H3]",
            "methyl sulfonyl": "*-[S;D4](=O)(=O)-[C;D1;H3]",
            "sulfonyl chloride": "*-[S;D4](=O)(=O)-[Cl]",
            "methyl sulfinyl": "*-[S;D3](=O)-[C;D1]",
            "methyl thio": "*-[S;D2]-[C;D1;H3]",
            "thiols": "*-[S;D1]",
            "thio carbonyls": "*=[S;D1]",
            "halogens": "*-[#9,#17,#35,#53]",
            "t-butyl": "*-[C;D4]([C;D1])([C;D1])-[C;D1]",
            "tri fluoromethyl": "*-[C;D4](F)(F)F",
            "acetylenes": "*-[C;D2]#[C;D1;H]",
            "cyclopropyl": "*-[C;D3]1-[C;D2]-[C;D2]1",
            "ethoxy": "*-[O;D2]-[C;D2]-[C;D1;H3]",
            "methoxy": "*-[O;D2]-[C;D1;H3]",
            "side-chain hydroxyls": "*-[O;D1]",
            "carbonyls": "*=[O;D1]",
            "primary amines": "*-[N;D1]",
            "nitriles": "*#[N;D1]",
        }

    def _is_fg_in_mol(self, mol, fg):
        """Return True when SMARTS pattern `fg` matches the SMILES string `mol`."""
        fgmol = Chem.MolFromSmarts(fg)
        mol = Chem.MolFromSmiles(mol.strip())
        return len(Chem.Mol.GetSubstructMatches(mol, fgmol, uniquify=True)) > 0

    def _run(self, smiles: str) -> str:
        """
        Input a molecule SMILES or name.
        Returns a list of functional groups identified by their common name (in natural language).
        """
        try:
            fgs_in_molec = [
                name
                for name, fg in self.dict_fgs.items()
                if self._is_fg_in_mol(smiles, fg)
            ]
            # BUG FIX: an empty result used to raise IndexError on
            # fgs_in_molec[0], which the bare except misreported as an
            # invalid SMILES. Report the empty case explicitly instead.
            if not fgs_in_molec:
                return "This molecule contains no known functional groups."
            if len(fgs_in_molec) > 1:
                return f"This molecule contains {', '.join(fgs_in_molec[:-1])}, and {fgs_in_molec[-1]}."
            return f"This molecule contains {fgs_in_molec[0]}."
        except Exception:
            # Narrowed from a bare except; still best-effort for bad input.
            return "Wrong argument. Please input a valid molecular SMILES."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
|
tool/rag.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Created on Sun Feb 2 20:31:22 2025
|
| 4 |
+
|
| 5 |
+
@author: BM109X32G-10GPU-02
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from langchain.tools import BaseTool
|
| 10 |
+
|
| 11 |
+
from langchain.prompts.chat import (
|
| 12 |
+
ChatPromptTemplate,
|
| 13 |
+
HumanMessagePromptTemplate,
|
| 14 |
+
SystemMessagePromptTemplate,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from langchain import PromptTemplate
|
| 18 |
+
from langchain import HuggingFacePipeline
|
| 19 |
+
|
| 20 |
+
from langchain.base_language import BaseLanguageModel
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
from langchain.chains import RetrievalQA
|
| 25 |
+
|
| 26 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 27 |
+
|
| 28 |
+
from langchain_openai import ChatOpenAI
|
| 29 |
+
from langchain_community.vectorstores import FAISS
|
| 30 |
+
from torch import cuda, bfloat16
|
| 31 |
+
import os

# Select a CUDA device string when a GPU is available (used by optional
# HuggingFace pipelines); falls back to CPU otherwise.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
from langchain_openai import OpenAIEmbeddings

# SECURITY FIX: the API key was hard-coded in source control; read it from the
# environment instead (OPENAI_API_KEY, optional OPENAI_BASE_URL override).
embeddings = OpenAIEmbeddings(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url=os.environ.get("OPENAI_BASE_URL", "https://www.dmxapi.com/v1"),
)
# FIX: the index location was a machine-specific absolute Windows path; it is
# now configurable via RAG_INDEX_PATH with the old path kept as the default.
# NOTE: allow_dangerous_deserialization unpickles the index — only load
# indexes you built yourself.
vectorstore = FAISS.load_local(
    os.environ.get("RAG_INDEX_PATH", r"J:\libray\osc\tool\rag"),
    embeddings,
    allow_dangerous_deserialization=True,
)


# Prompt for the RAG tool's RetrievalQA chain: answer only from the retrieved
# context and admit ignorance rather than fabricate.
template = """

You are an expert chemist and your task is to respond to the question or
solve the problem to the best of your ability.You can only respond with a single "Final Answer" format.
You need to list the key points and explain them in detail and accurately
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
<context>
{context}
</context>

Question: {question}
Answer:

"""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class rag(BaseTool):
    """Tool: retrieval-augmented Q&A over the prebuilt module-level FAISS index."""

    name: str = "RAG"
    description: str = (
        # FIX: the first sentence was truncated ("...require technical ").
        "Useful to answer questions that require technical knowledge. "
        "Provide specialized knowledge information for solving Q&A questions"
        "Input query , return the response"
    )

    llm: BaseLanguageModel = None  # chat model used for answering
    path: str = None               # unused; kept for interface compatibility

    def __init__(self, path: str = None):
        import os  # local import: keeps the module-level import block untouched

        super().__init__()
        # SECURITY FIX: the API key was previously hard-coded in source
        # control; credentials must come from the environment.
        self.llm = ChatOpenAI(
            model="gpt-4o-2024-11-20",
            api_key=os.environ.get("OPENAI_API_KEY"),
            base_url=os.environ.get("OPENAI_BASE_URL", "https://www.dmxapi.com/v1"),
        )
        self.path = path

    def _run(self, query) -> str:
        """Answer `query` using the top-5 chunks from the module-level index."""
        # FIX: declare "context" too — the "stuff" chain fills both variables.
        prompt = PromptTemplate(template=template, input_variables=["context", "question"])
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt},
        )
        # (removed an unused `chat_history` local left over from an earlier
        # ConversationalRetrievalChain version)
        result = qa_chain.invoke(query)
        return result['result']

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
tool/rag/index.faiss
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50ab8fd7a0c8d9dd62ebb1592b542b2d2fb2730ce21b761d2b36e8b5087743cf
|
| 3 |
+
size 6942765
|
tool/rag/index.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:908a85200a6d6e140c7d112f5c1b6b73376b42b2f85c3b3786f2050d6f1070d9
|
| 3 |
+
size 5545448
|
tool/search.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import langchain
|
| 5 |
+
|
| 6 |
+
import paperqa
|
| 7 |
+
import paperscraper
|
| 8 |
+
from langchain_community.utilities import SerpAPIWrapper
|
| 9 |
+
from langchain.base_language import BaseLanguageModel
|
| 10 |
+
from langchain.tools import BaseTool
|
| 11 |
+
from langchain_openai import OpenAIEmbeddings
|
| 12 |
+
from pypdf.errors import PdfReadError
|
| 13 |
+
from rdkit import Chem, DataStructs
|
| 14 |
+
from rdkit.Chem import AllChem
|
| 15 |
+
|
| 16 |
+
def is_smiles(text):
    """Return True if `text` parses as a SMILES string (without sanitization)."""
    try:
        # sanitize=False: accept syntactically valid SMILES even when RDKit's
        # chemistry checks (valence, aromaticity) would reject them.
        return Chem.MolFromSmiles(text, sanitize=False) is not None
    except Exception:
        # Narrowed from a bare except; RDKit raises on non-string input.
        return False
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def is_multiple_smiles(text):
    """Return True when `text` is a valid SMILES made of several '.'-joined fragments."""
    return "." in text if is_smiles(text) else False
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def split_smiles(text):
    """Split a multi-fragment SMILES on the '.' fragment separator."""
    fragments = text.split(".")
    return fragments
|
| 35 |
+
|
| 36 |
+
def paper_scrap(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
    """Search for papers with paperscraper; return {} when the backend raises KeyError."""
    try:
        results = paperscraper.search_papers(
            search,
            pdir=pdir,
            semantic_scholar_api_key=semantic_scholar_api_key,
        )
    except KeyError:
        # paperscraper raises KeyError on some malformed API responses.
        return {}
    return results
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def paper_search(llm, query, semantic_scholar_api_key=None):
    """Ask the LLM for a short literature-search query, then scrape matching papers."""
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
I would like to find scholarly papers to answer
this question: {question}. Your response must be at
most 10 words long.
'A search query that would bring up papers that can answer
this question would be: '""",
    )

    query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
    if not os.path.isdir("./query"):  # todo: move to ckpt
        os.mkdir("query/")

    search = query_chain.run(query)
    print("\nSearch:", search)
    # Downloaded PDFs go in a per-query directory (spaces stripped).
    return paper_scrap(
        search,
        pdir=f"query/{re.sub(' ', '', search)}",
        semantic_scholar_api_key=semantic_scholar_api_key,
    )
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
    """Useful to answer questions that require
    technical knowledge. Ask a specific question."""
    papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
    if len(papers) == 0:
        return "Not enough papers found"

    # paperqa builds a per-paper index and synthesizes a cited answer.
    docs = paperqa.Docs(
        llm=llm,
        summary_llm=llm,
        embeddings=OpenAIEmbeddings(openai_api_key=openai_api_key),
    )

    not_loaded = 0
    for path, data in papers.items():
        try:
            docs.add(path, data["citation"])
        except (ValueError, FileNotFoundError, PdfReadError):
            # Skip papers whose PDFs are missing or unparsable.
            not_loaded += 1

    if not_loaded > 0:
        print(f"\nFound {len(papers.items())} papers but couldn't load {not_loaded}.")
    else:
        print(f"\nFound {len(papers.items())} papers and loaded all of them.")

    return docs.query(query, k=k, max_sources=max_sources).formatted_answer
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Scholar2ResultLLM(BaseTool):
    """Tool wrapper around scholar2result_llm(): literature search + cited answer."""

    name: str = "LiteratureSearch"
    description: str = (
        "Useful to answer questions that require technical "
        "knowledge. Ask a specific question."
    )
    llm: BaseLanguageModel = None        # model used for query generation and answering
    openai_api_key: str = None           # forwarded to the embeddings backend
    semantic_scholar_api_key: str = None # forwarded to paperscraper

    def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
        super().__init__()
        self.llm = llm
        # api keys
        self.openai_api_key = openai_api_key
        self.semantic_scholar_api_key = semantic_scholar_api_key

    def _run(self, query) -> str:
        return scholar2result_llm(
            self.llm,
            query,
            openai_api_key=self.openai_api_key,
            semantic_scholar_api_key=self.semantic_scholar_api_key
        )

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def web_search(keywords, search_engine="google", serp_api_key=None):
    """Run a SerpAPI web search for `keywords` and return the answer text.

    `serp_api_key` is a new, optional, backward-compatible parameter; when it
    is None the key is read from the SERPAPI_API_KEY environment variable.
    Returns a fallback message on any failure (network, quota, bad key).
    """
    import os

    try:
        # SECURITY FIX: the SerpAPI key was previously hard-coded in source
        # control; it now comes from the caller or the environment.
        key = serp_api_key or os.environ.get("SERPAPI_API_KEY")
        return SerpAPIWrapper(
            serpapi_api_key=key, search_engine=search_engine
        ).run(keywords)
    except Exception:
        # Narrowed from a bare except; still best-effort for the agent loop.
        return "No results, try another search"


class WebSearch(BaseTool):
    """Tool wrapper around web_search()."""

    name: str = "WebSearch"
    description: str = (
        "Input a specific question, returns an answer from web search. "
        "Give more detailed information and use more general features to formulate your questions."
    )
    serp_api_key: str = None  # SerpAPI credential supplied at construction

    def __init__(self, serp_api_key: str = None):
        super().__init__()
        self.serp_api_key = serp_api_key

    def _run(self, query: str) -> str:
        if not self.serp_api_key:
            return (
                "No SerpAPI key found. This tool may not be used without a SerpAPI key."
            )
        # BUG FIX: the configured key was validated but never used — the call
        # previously fell through to a hard-coded key inside web_search().
        return web_search(query, serp_api_key=self.serp_api_key)

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("Async not implemented")
|
| 155 |
+
|
| 156 |
+
|