jinysun commited on
Commit
64e9ead
·
verified ·
1 Parent(s): dbaa85f

Upload 46 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tool/dap/util/data/bindingdb_kd.tab filter=lfs diff=lfs merge=lfs -text
37
+ tool/dap/util/data/davis.tab filter=lfs diff=lfs merge=lfs -text
38
+ tool/rag/index.faiss filter=lfs diff=lfs merge=lfs -text
tool/ImageAnalysis.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sat Oct 26 15:35:19 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from langchain_community.embeddings import OllamaEmbeddings
9
+ from langchain.tools import BaseTool
10
+ from langchain_openai import ChatOpenAI
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langchain.base_language import BaseLanguageModel
13
+ import base64
14
+ from io import BytesIO
15
+ from PIL import Image
16
+
17
+
18
+
19
+ def convert_to_base64(pil_image):
20
+ buffered = BytesIO()
21
+ pil_image.save(buffered, format="PNG")
22
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
23
+ return img_str
24
+
25
+
26
+ class Imageanalysis(BaseTool):
27
+ name: str = "Imageanalysis"
28
+ description: str = (
29
+ "Useful to answer questions according to the image, figure, diagram or graph. "
30
+ "Useful to analysis the information in the image, figure, diagram or graph. "
31
+ "Input query about image/figure/graph/diagram, return the response"
32
+ )
33
+ return_direct: bool = True
34
+ llm: BaseLanguageModel = None
35
+ path : str = None
36
+
37
+ def __init__(self, path):
38
+ super().__init__( )
39
+ self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
40
+ base_url="https://www.dmxapi.com/v1")
41
+ self.path = path
42
+ # api keys
43
+
44
+ def _run(self, query ) -> str:
45
+ try:
46
+ pil_image = Image.open(self.path)
47
+ rgb_im = pil_image.convert('RGB')
48
+ image_b64 = convert_to_base64(pil_image)
49
+ message = HumanMessage(
50
+ content=[
51
+ {"type": "text", "text": query},
52
+ {
53
+ "type": "image_url",
54
+ "image_url": {"url":f"data:image/jpeg;base64,{image_b64}"},
55
+ },
56
+ ],)
57
+ response = self.llm.invoke([message])
58
+ return response.content
59
+
60
+ except Exception as e:
61
+ return str(e)
62
+
63
+
64
+ async def _arun(self, query) -> str:
65
+ """Use the tool asynchronously."""
66
+ raise NotImplementedError("this tool does not support async")
67
+
68
+
tool/PCE.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Created on Wed Sep 11 10:27:20 2024
5
+
6
+ @author: BM109X32G-10GPU-02
7
+ """
8
+
9
+ from langchain.tools import BaseTool
10
+ from rdkit import Chem
11
+ from rdkit.Chem import rdMolDescriptors
12
+ from rdkit.Chem import Descriptors
13
+ from .deepacceptor import RF
14
+ from .deepdonor import sm, pm
15
+ from .dap import run, screen
16
+ import pandas as pd
17
+
18
+ class acceptor_predictor(BaseTool):
19
+ name:str = "acceptor_predictor"
20
+ description:str = (
21
+ "Input acceptor SMILES , returns the score of the acceptor."
22
+ )
23
+
24
+ def __init__(self):
25
+ super().__init__()
26
+ def _run(self, smiles: str) -> str:
27
+ mol = Chem.MolFromSmiles(smiles)
28
+ if mol is None:
29
+ return "Invalid SMILES string"
30
+ smiles = Chem.MolToSmiles(mol)
31
+ pce = RF.main( str(smiles) )
32
+ return f'The power conversion efficiency (PCE) is predicted to be {pce} (predicted by DeepAcceptor) '
33
+
34
+ async def _arun(self, smiles: str) -> str:
35
+ """Use the tool asynchronously."""
36
+ raise NotImplementedError()
37
+
38
+ class donor_predictor(BaseTool):
39
+ name:str = "donor_predictor"
40
+ description:str = (
41
+ "Input donor SMILES , returns the score of the donor."
42
+ )
43
+
44
+ def __init__(self):
45
+ super().__init__()
46
+ def _run(self, smiles: str) -> str:
47
+ mol = Chem.MolFromSmiles(smiles)
48
+ if mol is None:
49
+ return "Invalid SMILES string"
50
+
51
+ mol = Chem.MolFromSmiles(smiles)
52
+ if mol is None:
53
+ return "Invalid SMILES string"
54
+ sdpce = sm.main( str(smiles) )
55
+ pdpce = pm.main( str(smiles) )
56
+ return f'The power conversion efficiency (PCE) of the given molecule is predicted to be {sdpce} as a small molecule donor , and {pdpce} as a polymer donor(predicted by DeepDonor) '
57
+
58
+ async def _arun(self, smiles: str) -> str:
59
+ """Use the tool asynchronously."""
60
+ raise NotImplementedError()
61
+
62
+
63
+
64
+ class dap_predictor(BaseTool):
65
+ name:str = "dap_predictor"
66
+ description :str = (
67
+ "Input SMILES of D/A pairs(separated by '.') , returns the performance of the D/A pairs ."
68
+ )
69
+
70
+
71
+ def __init__(self):
72
+ super().__init__()
73
+
74
+ def _run(self, smiles_pair: str) -> str:
75
+ smi_list = smiles_pair.split(".")
76
+ if len(smi_list) != 2:
77
+
78
+ return "Input error, please input two smiles strings separated by '.'"
79
+
80
+ else:
81
+ smiles1, smiles2 = smi_list
82
+
83
+
84
+ pce = run.smiles_aas_test( str(smiles1 ), str(smiles2) )
85
+
86
+ return pce
87
+
88
+ async def _arun(self, smiles_pair: str) -> str:
89
+ """Use the tool asynchronously."""
90
+ raise NotImplementedError()
91
+
92
+
93
+
94
+ class dap_screen(BaseTool):
95
+ name:str = "dap_screen"
96
+ description :str = (
97
+ "Input dataset path containing D/A pairs, returns the files of prediction results."
98
+ )
99
+ return_direct: bool = True
100
+ def __init__(self):
101
+ super().__init__()
102
+
103
+ def _run(self, file_path: str) -> str:
104
+ smi_list = screen.smiles_aas_test(file_path)
105
+
106
+ return smi_list
107
+
108
+ async def _arun(self, smiles_pair: str) -> str:
109
+ """Use the tool asynchronously."""
110
+ raise NotImplementedError()
111
+
112
+
113
+ from .comget import generator
114
+
115
+ class molgen(BaseTool):
116
+ name: str = "donorgen"
117
+ description: str = (
118
+
119
+ "Useful to generate polymer donor molecules with required PCE. "
120
+ "Input the values of PCE , return the SMILES"
121
+ )
122
+
123
+
124
+ def __init__(self
125
+ ):
126
+ super().__init__( )
127
+
128
+
129
+ def _run(self, value ) -> str:
130
+ try:
131
+ results = generator.generation(value)
132
+ for i in results['smiles']:
133
+ pdpce = pm.main( str(i) )
134
+ if abs(pdpce-float(value))<1.0:
135
+ return f"The SMILES of generated donor is {i}, its predicted PCE is {pdpce}."
136
+ break
137
+
138
+
139
+
140
+ except Exception as e:
141
+ return str(e)
142
+
143
+
144
+ async def _arun(self, query) -> str:
145
+ """Use the tool asynchronously."""
146
+ raise NotImplementedError("this tool does not support async")
tool/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """load all tools."""
2
+
3
+ from .coder import *
4
+ from .property import *
5
+ from .search import *
6
+ from .PCE import *
7
+ from .converters import *
8
+ from .orbital import *
9
+ from .graphconverter import *
10
+ from .ImageAnalysis import *
11
+ from .pdfreader import *
12
+ from .rag import *
13
+ from .browsersearch import *
tool/browsersearch.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from browser_use import Agent
3
+ import asyncio
4
+ from dotenv import load_dotenv
5
+ load_dotenv()
6
+ from langchain.tools import BaseTool
7
+
8
+ async def main(task):
9
+ agent = Agent(
10
+ task=task,
11
+ llm=ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
12
+ base_url="https://www.dmxapi.com/v1"),
13
+ )
14
+ result = await agent.run()
15
+ return result
16
+
17
+ class browseruse(BaseTool):
18
+ name: str = "browseruse"
19
+ description: str = ("Calling the browser to search for information in specific website"
20
+ "input query, return the searching results")
21
+
22
+ def __init__(
23
+ self,
24
+ ):
25
+ super().__init__()
26
+
27
+ def _run(self, task: str) -> str:
28
+ result = asyncio.run(main(task))
29
+ return result
30
+
31
+ async def _arun(self, smiles: str) -> str:
32
+ """Use the tool asynchronously."""
33
+ raise NotImplementedError()
34
+
35
+
36
+
tool/chemspace.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import molbloom
4
+ import pandas as pd
5
+ import requests
6
+ from langchain.tools import BaseTool
7
+
8
+ from utils import is_smiles
9
+
10
+
11
+ class ChemSpace:
12
+ def __init__(self, chemspace_api_key=None):
13
+ self.chemspace_api_key = chemspace_api_key
14
+ self._renew_token() # Create token
15
+
16
+ def _renew_token(self):
17
+ self.chemspace_token = requests.get(
18
+ url="https://api.chem-space.com/auth/token",
19
+ headers={
20
+ "Accept": "application/json",
21
+ "Authorization": f"Bearer {self.chemspace_api_key}",
22
+ },
23
+ ).json()["access_token"]
24
+
25
+ def _make_api_request(
26
+ self,
27
+ query,
28
+ request_type,
29
+ count,
30
+ categories,
31
+ ):
32
+ """
33
+ Make a generic request to chem-space API.
34
+
35
+ Categories request.
36
+ CSCS: Custom Request: Could be useful for requesting whole synthesis
37
+ CSMB: Make-On-Demand Building Blocks
38
+ CSSB: In-Stock Building Blocks
39
+ CSSS: In-stock Screening Compounds
40
+ CSMS: Make-On-Demand Screening Compounds
41
+ """
42
+
43
+ def _do_request():
44
+ data = requests.request(
45
+ "POST",
46
+ url=f"https://api.chem-space.com/v3/search/{request_type}?count={count}&page=1&categories={categories}",
47
+ headers={
48
+ "Accept": "application/json; version=3.1",
49
+ "Authorization": f"Bearer {self.chemspace_token}",
50
+ },
51
+ data={"SMILES": f"{query}"},
52
+ ).json()
53
+ return data
54
+
55
+ data = _do_request()
56
+
57
+ # renew token if token is invalid
58
+ if "message" in data.keys():
59
+ if data["message"] == "Your request was made with invalid credentials.":
60
+ self._renew_token()
61
+
62
+ data = _do_request()
63
+ return data
64
+
65
+ def _convert_single(self, query, search_type: str):
66
+ """Do query for a single molecule"""
67
+ data = self._make_api_request(query, "exact", 1, "CSCS,CSMB,CSSB")
68
+ if data["count"] > 0:
69
+ return data["items"][0][search_type]
70
+ else:
71
+ return "No data was found for this compound."
72
+
73
+ def convert_mol_rep(self, query, search_type: str = "smiles"):
74
+ if ", " in query:
75
+ query_list = query.split(", ")
76
+ else:
77
+ query_list = [query]
78
+ smi = ""
79
+ try:
80
+ for q in query_list:
81
+ smi += f"{query}'s {search_type} is: {str(self._convert_single(q, search_type))}"
82
+ return smi
83
+ except Exception:
84
+ return "The input provided is wrong. Input either a single molecule, or multiple molecules separated by a ', '"
85
+
86
+ def buy_mol(
87
+ self,
88
+ smiles,
89
+ request_type="exact",
90
+ count=1,
91
+ ):
92
+ """
93
+ Get data about purchasing compounds.
94
+
95
+ smiles: smiles string of the molecule you want to buy
96
+ request_type: one of "exact", "sim" (search by similarity), "sub" (search by substructure).
97
+ count: retrieve data for this many substances max.
98
+ """
99
+
100
+ def purchasable_check(
101
+ s,
102
+ ):
103
+ if not is_smiles(s):
104
+ try:
105
+ s = self.convert_mol_rep(s, "smiles")
106
+ except:
107
+ return "Invalid SMILES string."
108
+
109
+ """Checks if molecule is available for purchase (ZINC20)"""
110
+ try:
111
+ r = molbloom.buy(s, canonicalize=True)
112
+ except:
113
+ print("invalid smiles")
114
+ return False
115
+ if r:
116
+ return True
117
+ else:
118
+ return False
119
+
120
+ purchasable = purchasable_check(smiles)
121
+
122
+ if request_type == "exact":
123
+ categories = "CSMB,CSSB"
124
+ elif request_type in ["sim", "sub"]:
125
+ categories = "CSSS,CSMS"
126
+
127
+ data = self._make_api_request(smiles, request_type, count, categories)
128
+
129
+ try:
130
+ if data["count"] == 0:
131
+ if purchasable:
132
+ return "Compound is purchasable, but price is unknown."
133
+ else:
134
+ return "Compound is not purchasable."
135
+ except KeyError:
136
+ return "Invalid query, try something else. "
137
+
138
+ print(f"Obtaining data for {data['count']} substances.")
139
+
140
+ dfs = []
141
+ # Convert this data into df
142
+ for item in data["items"]:
143
+ dfs_tmp = []
144
+ smiles = item["smiles"]
145
+ offers = item["offers"]
146
+
147
+ for off in offers:
148
+ df_tmp = pd.DataFrame(off["prices"])
149
+ df_tmp["vendorName"] = off["vendorName"]
150
+ df_tmp["time"] = off["shipsWithin"]
151
+ df_tmp["purity"] = off["purity"]
152
+
153
+ dfs_tmp.append(df_tmp)
154
+
155
+ df_this = pd.concat(dfs_tmp)
156
+ df_this["smiles"] = smiles
157
+ dfs.append(df_this)
158
+
159
+ df = pd.concat(dfs).reset_index(drop=True)
160
+
161
+ df["quantity"] = df["pack"].astype(str) + df["uom"]
162
+ df["time"] = df["time"].astype(str) + " days"
163
+
164
+ df = df.drop(columns=["pack", "uom"])
165
+ # Remove all entries that are not numbers
166
+ df = df[df["priceUsd"].astype(str).str.isnumeric()]
167
+
168
+ cheapest = df.iloc[df["priceUsd"].astype(float).idxmin()]
169
+ return f"{cheapest['quantity']} of this molecule cost {cheapest['priceUsd']} USD and can be purchased at {cheapest['vendorName']}."
170
+
171
+
172
+ class GetMoleculePrice(BaseTool):
173
+ name :str = "GetMoleculePrice"
174
+ description :str = "Get the cheapest available price of a molecule."
175
+ chemspace_api_key: str = None
176
+ url: str = None
177
+
178
+ def __init__(self, chemspace_api_key: str = None):
179
+ super().__init__()
180
+ self.chemspace_api_key = chemspace_api_key
181
+ self.url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}"
182
+
183
+ def _run(self, query: str) -> str:
184
+ if not self.chemspace_api_key:
185
+ return "No Chemspace API key found. This tool may not be used without a Chemspace API key."
186
+ try:
187
+ chemspace = ChemSpace(self.chemspace_api_key)
188
+ price = chemspace.buy_mol(query)
189
+ return price
190
+ except Exception as e:
191
+ return str(e)
192
+
193
+ async def _arun(self, query: str) -> str:
194
+ """Use the tool asynchronously."""
195
+ raise NotImplementedError()
tool/coder.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sat Oct 26 15:35:19 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from langchain_community.embeddings import OllamaEmbeddings
9
+ from langchain.tools import BaseTool
10
+ from langchain_openai import ChatOpenAI
11
+ from langchain_core.messages import HumanMessage, SystemMessage
12
+ from langchain.base_language import BaseLanguageModel
13
+
14
+
15
+ class codewriter(BaseTool):
16
+ name:str = "codewriter"
17
+ description:str = (
18
+ "Useful to answer questions that require writing codes "
19
+ "return the usage and instruction of codes"
20
+ )
21
+
22
+ llm: BaseLanguageModel = None
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
26
+ base_url="https://www.dmxapi.com/v1")
27
+ # api keys
28
+
29
+ def _run(self, query) -> str:
30
+ messages = [
31
+ SystemMessage(content="You are an expert at writing code, write the corresponding code based on the inputs"),
32
+ HumanMessage(content=query),
33
+ ]
34
+
35
+ response = self.llm.invoke(messages)
36
+ return response
37
+
38
+ async def _arun(self, query) -> str:
39
+ """Use the tool asynchronously."""
40
+ raise NotImplementedError("this tool does not support async")
41
+
42
+
43
+
tool/comget/dataset.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from utils import SmilesEnumerator
4
+ import numpy as np
5
+ import re
6
+
7
+ class SmileDataset(Dataset):
8
+
9
+ def __init__(self, args, data, content, block_size, aug_prob = 0.5, prop = None, scaffold = None, scaffold_maxlen = None):
10
+ chars = sorted(list(set(content)))
11
+ data_size, vocab_size = len(data), len(chars)
12
+ print('data has %d smiles, %d unique characters.' % (data_size, vocab_size))
13
+
14
+ self.stoi = { ch:i for i,ch in enumerate(chars) }
15
+ self.itos = { i:ch for i,ch in enumerate(chars) }
16
+ self.max_len = block_size
17
+ self.vocab_size = vocab_size
18
+ self.data = data
19
+ self.prop = prop
20
+ self.sca = scaffold
21
+ self.scaf_max_len = scaffold_maxlen
22
+ self.debug = args.debug
23
+ self.tfm = SmilesEnumerator()
24
+ self.aug_prob = aug_prob
25
+
26
+ def __len__(self):
27
+ if self.debug:
28
+ return math.ceil(len(self.data) / (self.max_len + 1))
29
+ else:
30
+ return len(self.data)
31
+
32
+ def __getitem__(self, idx):
33
+ smiles, prop, scaffold = self.data[idx], self.prop[idx], self.sca[idx] # self.prop.iloc[idx, :].values --> if multiple properties
34
+ smiles = smiles.strip()
35
+ scaffold = scaffold.strip()
36
+
37
+ p = np.random.uniform()
38
+ if p < self.aug_prob:
39
+ smiles = self.tfm.randomize_smiles(smiles)
40
+
41
+ pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
42
+ regex = re.compile(pattern)
43
+ smiles += str('<')*(self.max_len - len(regex.findall(smiles)))
44
+
45
+ if len(regex.findall(smiles)) > self.max_len:
46
+ smiles = smiles[:self.max_len]
47
+
48
+ smiles=regex.findall(smiles)
49
+
50
+ scaffold += str('<')*(self.scaf_max_len - len(regex.findall(scaffold)))
51
+
52
+ if len(regex.findall(scaffold)) > self.scaf_max_len:
53
+ scaffold = scaffold[:self.scaf_max_len]
54
+
55
+ scaffold=regex.findall(scaffold)
56
+
57
+ dix = [self.stoi[s] for s in smiles]
58
+ sca_dix = [self.stoi[s] for s in scaffold]
59
+
60
+ sca_tensor = torch.tensor(sca_dix, dtype=torch.long)
61
+ x = torch.tensor(dix[:-1], dtype=torch.long)
62
+ y = torch.tensor(dix[1:], dtype=torch.long)
63
+ # prop = torch.tensor([prop], dtype=torch.long)
64
+ prop = torch.tensor([prop], dtype = torch.float)
65
+ return x, y, prop, sca_tensor
tool/comget/generator.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import pandas as pd
3
+
4
+ import math
5
+ from tqdm import tqdm
6
+ import argparse
7
+ from .model import GPT, GPTConfig
8
+ import pandas as pd
9
+ import torch
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ #import seaborn as sns
13
+ from .moses.utils import get_mol
14
+ import re
15
+
16
+ import json
17
+ from rdkit.Chem import RDConfig
18
+
19
+ import selfies as sf
20
+ import os
21
+ import sys
22
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
23
+ from .utils import sample, canonic_smiles
24
+ import sascorer
25
+ from rdkit import Chem
26
+ from rdkit.Chem.rdMolDescriptors import CalcTPSA
27
+ import os
28
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
29
+
30
+ def get_selfie_and_smiles_encodings_for_dataset(smiles):
31
+ """
32
+ Returns encoding, alphabet and length of largest molecule in SMILES and
33
+ SELFIES, given a file containing SMILES molecules.
34
+
35
+ input:
36
+ csv file with molecules. Column's name must be 'smiles'.
37
+ output:
38
+ - selfies encoding
39
+ - selfies alphabet
40
+ - longest selfies string
41
+ - smiles encoding (equivalent to file content)
42
+ - smiles alphabet (character based)
43
+ - longest smiles string
44
+ """
45
+
46
+ smiles_list = np.asanyarray(smiles)
47
+
48
+ smiles_alphabet = list(set("".join(smiles_list)))
49
+ smiles_alphabet.append(" ") # for padding
50
+
51
+ largest_smiles_len = len(max(smiles_list, key=len))
52
+
53
+ print("--> Translating SMILES to SELFIES...")
54
+ selfies_list = list(map(sf.encoder, smiles_list))
55
+
56
+ all_selfies_symbols = sf.get_alphabet_from_selfies(selfies_list)
57
+ all_selfies_symbols.add("[nop]")
58
+ selfies_alphabet = list(all_selfies_symbols)
59
+
60
+ largest_selfies_len = max(sf.len_selfies(s) for s in selfies_list)
61
+
62
+ print("Finished translating SMILES to SELFIES.")
63
+
64
+ return selfies_list, selfies_alphabet, largest_selfies_len, \
65
+ smiles_list, smiles_alphabet, largest_smiles_len
66
+
67
+ def generation(value):
68
+ parser = argparse.ArgumentParser()
69
+ #parser.add_argument('--model_weight', type=str, help="path of model weights", required=True)
70
+ parser.add_argument('--scaffold', action='store_true', default=False, help='condition on scaffold')
71
+ parser.add_argument('--lstm', action='store_true', default=False, help='use lstm for transforming scaffold')
72
+ #parser.add_argument('--csv_name', type=str, help="name to save the generated mols in csv format", required=True)
73
+ parser.add_argument('--data_name', type=str, default = 'moses2', help="name of the dataset to train on", required=False)
74
+ parser.add_argument('--batch_size', type=int, default = 512, help="batch size", required=False)
75
+ parser.add_argument('--gen_size', type=int, default = 10000, help="number of times to generate from a batch", required=False)
76
+ parser.add_argument('--vocab_size', type=int, default = 26, help="number of layers", required=False) # previously 28 .... 26 for moses. 94 for guacamol
77
+ parser.add_argument('--block_size', type=int, default = 54, help="number of layers", required=False) # previously 57... 54 for moses. 100 for guacamol.
78
+ # parser.add_argument('--num_props', type=int, default = 0, help="number of properties to use for condition", required=False)
79
+ parser.add_argument('--props', nargs="+", default = [], help="properties to be used for condition", required=False)
80
+ parser.add_argument('--n_layer', type=int, default = 8, help="number of layers", required=False)
81
+ parser.add_argument('--n_head', type=int, default = 8, help="number of heads", required=False)
82
+ parser.add_argument('--n_embd', type=int, default = 256, help="embedding dimension", required=False)
83
+ parser.add_argument('--lstm_layers', type=int, default = 2, help="number of layers in lstm", required=False)
84
+
85
+ args = parser.parse_args()
86
+ args.data_name = 'ppcenos'
87
+ args.vocab_size = 29 #
88
+ args.block_size = 196 #max_len
89
+ args.gen_size = 20
90
+ args.batch_size = 5
91
+ args.csv_name = 'ppcenos'
92
+ args.props = ['pce']
93
+ context = "[C]"
94
+ args.scaffold = False
95
+
96
+
97
+ pattern = "(\[[^\]]+]|<|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
98
+ regex = re.compile(pattern)
99
+
100
+ if ('moses' in args.data_name) and args.scaffold:
101
+ scaffold_max_len=48
102
+ elif ('guacamol' in args.data_name):
103
+ scaffold_max_len = 107
104
+ else:
105
+ scaffold_max_len = 181
106
+
107
+
108
+ stoi = json.load(open('tool/comget/' + f'{args.data_name}.json', 'r'))
109
+
110
+ # itos = { i:ch for i,ch in enumerate(chars) }
111
+ itos = { i:ch for ch,i in stoi.items() }
112
+
113
+
114
+ print(len(itos))
115
+
116
+
117
+ num_props = len(args.props)
118
+ mconf = GPTConfig(args.vocab_size, args.block_size, num_props = num_props,
119
+ n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, scaffold = args.scaffold, scaffold_maxlen = scaffold_max_len,
120
+ lstm = args.lstm, lstm_layers = args.lstm_layers)
121
+ model = GPT(mconf)
122
+
123
+ args.model_weight = f'{args.csv_name}.pt'
124
+ model.load_state_dict(torch.load('tool/comget/' + args.model_weight))
125
+ model.to('cuda')
126
+ print('Model loaded')
127
+
128
+ gen_iter = math.ceil(args.gen_size / args.batch_size)
129
+ # gen_iter = 2
130
+
131
+ if 'guacamol1' in args.data_name:
132
+ prop2value = {'qed': [0.3, 0.5, 0.7], 'sas': [2.0, 3.0, 4.0], 'logp': [2.0, 4.0, 6.0], 'tpsa': [40.0, 80.0, 120.0],
133
+ 'tpsa_logp': [[40.0, 2.0], [80.0, 2.0], [120.0, 2.0], [40.0, 4.0], [80.0, 4.0], [120.0, 4.0], [40.0, 6.0], [80.0, 6.0], [120.0, 6.0]],
134
+ 'sas_logp': [[2.0, 2.0], [2.0, 4.0], [2.0, 6.0], [3.0, 2.0], [3.0, 4.0], [3.0, 6.0], [4.0, 2.0], [4.0, 4.0], [4.0, 6.0]],
135
+ 'tpsa_sas': [[40.0, 2.0], [80.0, 2.0], [120.0, 2.0], [40.0, 3.0], [80.0, 3.0], [120.0, 3.0], [40.0, 4.0], [80.0, 4.0], [120.0, 4.0]],
136
+ 'tpsa_logp_sas': [[40.0, 2.0, 2.0], [40.0, 2.0, 4.0], [40.0, 6.0, 4.0], [40.0, 6.0, 2.0], [80.0, 6.0, 4.0], [80.0, 2.0, 4.0], [80.0, 2.0, 2.0], [80.0, 6.0, 2.0]]}
137
+ else:
138
+ prop2value = { 'pce': [float(value)]}
139
+
140
+
141
+ prop_condition = None
142
+ if len(args.props) > 0:
143
+ prop_condition = prop2value['_'.join(args.props)]
144
+
145
+ scaf_condition = None
146
+
147
+
148
+ all_dfs = []
149
+ all_metrics = []
150
+
151
+
152
+ count = 0
153
+
154
+ if prop_condition is not None and scaf_condition is None :
155
+
156
+ for c in prop_condition:
157
+ molecules = []
158
+ selfies = []
159
+ count += 1
160
+ for i in tqdm(range(gen_iter)):
161
+ x = torch.tensor([stoi[s] for s in regex.findall(context)], dtype=torch.long)[None,...].repeat(args.batch_size, 1).to('cuda')
162
+ p = None
163
+ if len(args.props) == 1:
164
+ p = torch.tensor([c]).repeat(args.batch_size, 1).to('cuda') # for single condition
165
+ else:
166
+ p = torch.tensor([c]).repeat(args.batch_size, 1).unsqueeze(1).to('cuda') # for multiple conditions
167
+ sca = None
168
+ y = sample(model, x, 300, temperature= 1.0, sample=True, top_k = 10, prop = p, scaffold = sca)
169
+ for gen_mol in y:
170
+ completion = ''.join([itos[int(i)] for i in gen_mol])
171
+ completion = completion.replace('<', '')
172
+ selfies.append(completion)
173
+ file = pd.DataFrame(selfies)
174
+
175
+ for ind, i in enumerate( file[0]):
176
+
177
+ smi = (sf.decoder(eval(repr(i))))
178
+ mol = get_mol(smi)
179
+ # gen_smiles.append(completion)
180
+
181
+ if mol:
182
+
183
+ molecules.append(mol)
184
+ else:
185
+ print(ind)
186
+ print(i)
187
+
188
+ "Valid molecules % = {}".format(len(molecules))
189
+
190
+ mol_dict = []
191
+
192
+
193
+ for i in molecules:
194
+ mol_dict.append({'molecule' : i, 'smiles': Chem.MolToSmiles(i)})
195
+
196
+ # for i in gen_smiles:
197
+ # mol_dict.append({'temperature' : temp, 'smiles': i})
198
+
199
+
200
+ results = pd.DataFrame(mol_dict)
201
+
202
+ all_dfs.append(results)
203
+
204
+ results = pd.concat(all_dfs)
205
+
206
+ return results
tool/comget/model.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT model:
3
+ - the initial stem consists of a combination of token encoding and a positional encoding
4
+ - the meat of it is a uniform sequence of Transformer blocks
5
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
6
+ - all blocks feed into a central residual pathway similar to resnets
7
+ - the final decoder is a linear projection into a vanilla Softmax classifier
8
+ """
9
+
10
+ import math
11
+ import logging
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    # Default dropout rates shared by every variant.
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        # Every extra keyword argument becomes an attribute on the config.
        for name, value in kwargs.items():
            setattr(self, name, value)
31
class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    # Transformer dimensions matching the original GPT-1 architecture;
    # inherits vocab/block sizes and dropout rates from GPTConfig.
    n_layer = 12
    n_head = 12
    n_embd = 768
37
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # Learnable per-feature gain; initialized to 1 (identity scaling).
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares is computed in float32 for numerical stability,
        # then the result is cast back to the input dtype.
        mean_sq = x.to(torch.float32).pow(2).mean(dim=self.dim, keepdim=True)
        normalized = x * torch.rsqrt(mean_sq + self.eps)
        return (self.scale * normalized).type_as(x)
60
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        # Head dimension must divide the embedding evenly.
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        # NOTE(review): self.key/self.query/self.value and self.proj below are
        # never used by forward() (q_proj/kv_proj/c_proj are used instead) —
        # they look like dead parameters kept for checkpoint compatibility; confirm.
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.q_proj = nn.Linear(
            config.n_embd ,
            config.n_embd ,
            bias=False,
        )
        # key, value projections (fused: output is split into k and v in forward)
        self.kv_proj = nn.Linear(
            config.n_embd ,
            2 * config.n_embd ,
            bias=False,
        )
        # output projection
        self.c_proj = nn.Linear(
            config.n_embd ,
            config.n_embd ,
            bias=False,
        )
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # The mask is enlarged by `num` to cover the extra conditioning tokens
        # (property flag + scaffold length) prepended to the sequence.
        num = int(bool(config.num_props)) + int(config.scaffold_maxlen) #int(config.lstm_layers) # int(config.scaffold)
        # num = 1
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size + num, config.block_size + num))
                                     .view(1, 1, config.block_size + num, config.block_size + num))

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, layer_past=None):
        # x: (batch B, sequence T, embedding C)
        B, T, C = x.size()

        q = self.q_proj(x)
        k, v = self.kv_proj(x).split(self.n_embd, dim=2)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(
            1, 2
        )
        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # y = F.scaled_dot_product_attention(
        #     q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True
        # )
        # Scaled dot-product scores, masked so position t only sees <= t.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        # Keep the pre-dropout attention map so callers can inspect it.
        attn_save = att
        att = self.attn_drop(att)
        y = att @ v
        # Re-assemble all head outputs side by side: (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # output projection
        y = self.c_proj(y)

        # Returns (output, attention weights before dropout).
        return y, attn_save
138
def find_multiple(n , k ) :
    """Return the smallest multiple of k that is greater than or equal to n."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
142
+
143
+
144
class MLP(nn.Module):
    """SwiGLU-style feed-forward block: silu(W1 x) * (W2 x), projected back."""

    def __init__(self, config ) :
        super().__init__()
        # Hidden width: 2/3 of the base width, rounded up to a multiple of 256
        # (LLaMA-style sizing; base width here is 4 * n_embd * n_head).
        base_width = 4 * config.n_embd * config.n_head
        n_hidden = find_multiple(int(2 * base_width / 3), 256)

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x):
        gated = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        return self.c_proj(gated)
165
+
166
class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd )
        self.rms_2 = RMSNorm(config.n_embd )
        # NOTE: the LayerNorms below are registered but forward() only uses
        # the RMSNorm instances; kept as-is for state-dict compatibility.
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-norm residual attention, then pre-norm residual MLP.
        attn_out, attn_weights = self.attn(self.rms_1(x))
        x = x + attn_out
        x = x + self.mlp(self.rms_2(x))
        return x, attn_weights
182
+
183
class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        # type embedding distinguishes conditioning tokens (type 0) from
        # ordinary sequence tokens (type 1); see forward().
        self.type_emb = nn.Embedding(2, config.n_embd)
        if config.num_props:
            # projects the property vector into the embedding space
            self.prop_nn = nn.Linear(config.num_props, config.n_embd)

        # learned absolute position embeddings
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = RMSNorm(config.n_embd )
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size

        if config.lstm:
            # optional LSTM encoder used to summarise scaffold tokens
            self.lstm = nn.LSTM(input_size = config.n_embd, hidden_size = config.n_embd, num_layers = config.lstm_layers, dropout = 0.3, bidirectional = False)
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        # maximum sequence length supported by pos_emb and the causal mask
        return self.block_size

    def _init_weights(self, module):
        # GPT-2-style init, scaled down with depth
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(
                module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)
            )

    def configure_optimizers(self, parameters, train_config):
        # plain AdamW over the given parameter set (no weight-decay grouping)
        optimizer = torch.optim.AdamW(parameters, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None, prop = None, scaffold = None):
        """Run the model.

        Args:
            idx: (b, t) token indices.
            targets: optional (b, t) token indices for the LM loss.
            prop: optional property vector(s) prepended as conditioning.
            scaffold: optional scaffold token indices prepended as conditioning.

        Returns:
            (logits, loss, attn_maps); logits covers only the real sequence
            positions — conditioning slots are stripped before the loss.
        """
        b, t = idx.size()

        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        if self.config.num_props:
            assert prop.size(-1) == self.config.num_props, "Num_props should be equal to last dim of property vector"

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        # ordinary sequence tokens get type id 1
        type_embeddings = self.type_emb(torch.ones((
            b,t), dtype = torch.long, device = idx.device))
        x = self.drop(token_embeddings + position_embeddings + type_embeddings)

        if self.config.num_props:
            # conditioning tokens get type id 0 and are prepended to x
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))
            if prop.ndim == 2:
                p = self.prop_nn(prop.unsqueeze(1)) # for single property
            else:
                p = self.prop_nn(prop) # for multiproperty
            p += type_embd
            x = torch.cat([p, x], 1)

        if self.config.scaffold:
            type_embd = self.type_emb(torch.zeros((b, 1), dtype = torch.long, device = idx.device))

            scaffold_embeds = self.tok_emb(scaffold) # .mean(1, keepdim = True)
            if self.config.lstm:
                # use the LSTM final hidden state(s) as the scaffold encoding
                scaffold_embeds = self.lstm(scaffold_embeds.permute(1,0,2))[1][0]
                scaffold_embeds = scaffold_embeds.permute(1,0,2) # mean(0, keepdim = True)
            scaffold_embeds += type_embd
            x = torch.cat([scaffold_embeds, x], 1)

        # run the transformer blocks, collecting per-layer attention maps
        attn_maps = []

        for layer in self.blocks:
            x, attn = layer(x)
            attn_maps.append(attn)

        x = self.ln_f(x)
        logits = self.head(x)

        # number of conditioning slots prepended above; strip them from logits
        if self.config.num_props and self.config.scaffold:
            num = int(bool(self.config.num_props)) + int(self.config.scaffold_maxlen)
        elif self.config.num_props:
            num = int(bool(self.config.num_props))
        elif self.config.scaffold:
            num = int(self.config.scaffold_maxlen)
        else:
            num = 0

        logits = logits[:, num:, :]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1))

        return logits, loss, attn_maps # (num_layers, batch_size, num_heads, max_seq_len, max_seq_len)
tool/comget/ppcenos.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<": 0, "[#Branch1]": 1, "[#Branch2]": 2, "[#C]": 3, "[#N]": 4, "[=Branch1]": 5, "[=Branch2]": 6, "[=C]": 7, "[=N]": 8, "[=O]": 9, "[=Ring1]": 10, "[=Ring2]": 11, "[=S]": 12, "[Branch1]": 13, "[Branch2]": 14, "[C]": 15, "[Cl]": 16, "[F]": 17, "[GeH2]": 18, "[Ge]": 19, "[NH1]": 20, "[N]": 21, "[O]": 22, "[P]": 23, "[Ring1]": 24, "[Ring2]": 25, "[S]": 26, "[Se]": 27, "[nop]": 28}
tool/comget/ppcenos.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ddee4e16df14ee00e9736c66755977774edf1259c46afb03469c99ca7659fbf5
3
+ size 160173846
tool/comget/utils.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .moses.utils import get_mol
7
+ from rdkit import Chem
8
+
9
+ import numpy as np
10
+ import threading
11
+
12
def set_seed(seed):
    """Seed the python, numpy and torch (all GPUs) RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
17
+
18
def top_k_logits(logits, k):
    """Mask everything below the k-th largest logit per row with -inf."""
    topk_vals, _ = torch.topk(logits, k)
    threshold = topk_vals[:, [-1]]  # k-th largest value, kept as a column
    clipped = logits.clone()
    clipped[clipped < threshold] = -float('Inf')
    return clipped
23
+
24
@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None, prop = None, scaffold = None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.

    Args:
        model: GPT-style model returning (logits, loss, attn_maps).
        x: (b, t) conditioning token indices.
        steps: number of new tokens to generate.
        temperature: logit divisor; >1 flattens, <1 sharpens the distribution.
        sample: if True sample from the distribution, else take the argmax.
        top_k: if set, restrict to the k most likely tokens before sampling.
        prop, scaffold: optional conditioning forwarded to the model.

    Returns:
        x extended with `steps` generated tokens, shape (b, t + steps).
    """
    block_size = model.get_block_size()
    model.eval()

    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _, _ = model(x_cond, prop = prop, scaffold = scaffold) # for liggpt
        # logits, _, _ = model(x_cond) # for char_rnn
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
55
+
56
def check_novelty(gen_smiles, train_smiles): # gen: say 788, train: 120803
    """Return the percentage of generated SMILES not present in the training set.

    Args:
        gen_smiles: list of generated SMILES strings.
        train_smiles: collection of training SMILES strings.

    Returns:
        Novelty ratio in percent (0.0 when nothing was generated).
    """
    if len(gen_smiles) == 0:
        novel_ratio = 0.
    else:
        # Build a set once so each membership test is O(1) instead of scanning
        # the (potentially very large) training list per generated molecule.
        train_set = set(train_smiles)
        duplicates = sum(1 for mol in gen_smiles if mol in train_set) # e.g. 45
        novel = len(gen_smiles) - duplicates # 788-45=743
        novel_ratio = novel*100./len(gen_smiles) # 743*100/788=94.289
    print("novelty: {:.3f}%".format(novel_ratio))
    return novel_ratio
65
+
66
def canonic_smiles(smiles_or_mol):
    """Canonicalise a SMILES string or Mol; return None if it cannot be parsed."""
    mol = get_mol(smiles_or_mol)
    return None if mol is None else Chem.MolToSmiles(mol)
71
+
72
+ #Experimental Class for Smiles Enumeration, Iterator and SmilesIterator adapted from Keras 1.2.2
73
+
74
class Iterator(object):
    """Abstract base class for data iterators.
    # Arguments
        n: Integer, total number of samples in the dataset to loop over.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.

    Subclasses must provide a `next()` method (see SmilesIterator); __next__
    delegates to it.
    """

    def __init__(self, n, batch_size, shuffle, seed):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batch_index = 0          # index of the next batch within the current epoch
        self.total_batches_seen = 0   # monotonically increases across epochs
        self.lock = threading.Lock()  # guards index generation across threads
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
        if n < batch_size:
            raise ValueError('Input data length is shorter than batch_size\nAdjust batch_size')

    def reset(self):
        # Restart the epoch bookkeeping (does not reset total_batches_seen).
        self.reset.__doc__
        self.batch_index = 0

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        # Infinite generator yielding (index_array_slice, current_index,
        # current_batch_size); the last batch of an epoch may be short.
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
            if seed is not None:
                # Re-seed per batch so shuffling is reproducible yet varies
                # across batches.
                np.random.seed(seed + self.total_batches_seen)
            if self.batch_index == 0:
                # start of a new epoch: build (optionally shuffled) index order
                index_array = np.arange(n)
                if shuffle:
                    index_array = np.random.permutation(n)

            current_index = (self.batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                self.batch_index += 1
            else:
                # final (possibly short) batch of the epoch
                current_batch_size = n - current_index
                self.batch_index = 0
            self.total_batches_seen += 1
            yield (index_array[current_index: current_index + current_batch_size],
                   current_index, current_batch_size)

    def __iter__(self):
        # Needed if we want to do something like:
        # for x, y in data_gen.flow(...):
        return self

    def __next__(self, *args, **kwargs):
        # Delegates to the subclass-provided next() (python 2 style API).
        return self.next(*args, **kwargs)
126
+
127
+
128
+
129
+
130
class SmilesIterator(Iterator):
    """Iterator yielding data from a SMILES array.
    # Arguments
        x: Numpy array of SMILES input data.
        y: Numpy array of targets data.
        smiles_data_generator: Instance of `SmilesEnumerator`
            to use for random SMILES generation.
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seed for data shuffling.
        dtype: dtype to use for returned batch. Set to keras.backend.floatx if using Keras
    """

    def __init__(self, x, y, smiles_data_generator,
                 batch_size=32, shuffle=False, seed=None,
                 dtype=np.float32
                 ):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (np.asarray(x).shape, np.asarray(y).shape))

        self.x = np.asarray(x)

        if y is not None:
            self.y = np.asarray(y)
        else:
            self.y = None
        self.smiles_data_generator = smiles_data_generator
        self.dtype = dtype
        super(SmilesIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

    def next(self):
        """For python 2.x.
        # Returns
            The next batch: one-hot array (batch, pad, charlen), optionally
            paired with the matching slice of y.
        """
        # Keeps under lock only the mechanism which advances
        # the indexing of each batch.
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        # The transformation of images is not under thread lock
        # so it can be done in parallel
        batch_x = np.zeros(tuple([current_batch_size] + [ self.smiles_data_generator.pad, self.smiles_data_generator._charlen]), dtype=self.dtype)
        for i, j in enumerate(index_array):
            # slice (not index) so transform() receives a 1-element array
            smiles = self.x[j:j+1]
            x = self.smiles_data_generator.transform(smiles)
            batch_x[i] = x

        if self.y is None:
            return batch_x
        batch_y = self.y[index_array]
        return batch_x, batch_y
184
+
185
+
186
class SmilesEnumerator(object):
    """SMILES Enumerator, vectorizer and devectorizer

    #Arguments
        charset: string containing the characters for the vectorization
          can also be generated via the .fit() method
        pad: Length of the vectorization
        leftpad: Add spaces to the left of the SMILES
        isomericSmiles: Generate SMILES containing information about stereogenic centers
        enum: Enumerate the SMILES during transform
        canonical: use canonical SMILES during transform (overrides enum)
    """
    def __init__(self, charset = '@C)(=cOn1S2/H[N]\\', pad=120, leftpad=True, isomericSmiles=True, enum=True, canonical=False):
        self._charset = None
        # assignment goes through the property setter, building lookup tables
        self.charset = charset
        self.pad = pad
        self.leftpad = leftpad
        self.isomericSmiles = isomericSmiles
        self.enumerate = enum
        self.canonical = canonical

    @property
    def charset(self):
        return self._charset

    @charset.setter
    def charset(self, charset):
        # Rebuild char<->index lookup tables whenever the charset changes.
        self._charset = charset
        self._charlen = len(charset)
        self._char_to_int = dict((c,i) for i,c in enumerate(charset))
        self._int_to_char = dict((i,c) for i,c in enumerate(charset))

    def fit(self, smiles, extra_chars=[], extra_pad = 5):
        """Performs extraction of the charset and length of a SMILES datasets and sets self.pad and self.charset

        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings
            extra_chars: List of extra chars to add to the charset (e.g. "\\\\" when "/" is present)
            extra_pad: Extra padding to add before or after the SMILES vectorization
        """
        # extra_chars has a mutable default, but it is never mutated here.
        charset = set("".join(list(smiles)))
        self.charset = "".join(charset.union(set(extra_chars)))
        self.pad = max([len(smile) for smile in smiles]) + extra_pad

    def randomize_smiles(self, smiles):
        """Perform a randomization of a SMILES string
        must be RDKit sanitizable"""
        # Shuffle atom order and re-emit; gives an alternative valid SMILES
        # for the same molecule.
        m = Chem.MolFromSmiles(smiles)
        ans = list(range(m.GetNumAtoms()))
        np.random.shuffle(ans)
        nm = Chem.RenumberAtoms(m,ans)
        return Chem.MolToSmiles(nm, canonical=self.canonical, isomericSmiles=self.isomericSmiles)

    def transform(self, smiles):
        """Perform an enumeration (randomization) and vectorization of a Numpy array of smiles strings
        #Arguments
            smiles: Numpy array or Pandas series containing smiles as strings

        Returns a one-hot array of shape (len(smiles), pad, charlen).
        """
        one_hot = np.zeros((smiles.shape[0], self.pad, self._charlen),dtype=np.int8)

        if self.leftpad:
            for i,ss in enumerate(smiles):
                if self.enumerate: ss = self.randomize_smiles(ss)
                l = len(ss)
                diff = self.pad - l
                # left-pad: characters written at the right end of the row
                for j,c in enumerate(ss):
                    one_hot[i,j+diff,self._char_to_int[c]] = 1
            return one_hot
        else:
            for i,ss in enumerate(smiles):
                if self.enumerate: ss = self.randomize_smiles(ss)
                for j,c in enumerate(ss):
                    one_hot[i,j,self._char_to_int[c]] = 1
            return one_hot


    def reverse_transform(self, vect):
        """ Performs a conversion of a vectorized SMILES to a smiles strings
        charset must be the same as used for vectorization.
        #Arguments
            vect: Numpy array of vectorized SMILES.
        """
        smiles = []
        for v in vect:
            # mask v: drop padding rows (rows whose one-hot sum is not 1)
            v=v[v.sum(axis=1)==1]
            # Find one hot encoded index with argmax, translate to char and join to string
            smile = "".join(self._int_to_char[i] for i in v.argmax(axis=1))
            smiles.append(smile)
        return np.array(smiles)
tool/converters.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.tools import BaseTool
2
+
3
+ from tool.chemspace import ChemSpace
4
+ import pandas as pd
5
+
6
+ from utils import (
7
+ is_multiple_smiles,
8
+ is_smiles,
9
+ pubchem_query2smiles,
10
+ query2cas,
11
+ smiles2name,
12
+ )
13
+
14
+
15
class Query2CAS(BaseTool):
    # Tool metadata consumed by the LangChain agent.
    name:str = "Mol2CAS"
    description:str = "Input molecule (name or SMILES), returns CAS number."
    url_cid: str = None
    url_data: str = None


    def __init__(
        self,
    ):
        super().__init__()
        # PubChem REST endpoints: the first resolves a CID, the second fetches
        # the full compound record for that CID.
        self.url_cid = (
            "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/{}/{}/cids/JSON"
        )
        self.url_data = (
            "https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/data/compound/{}/JSON"
        )

    def _run(self, query: str) -> str:
        """Resolve a molecule name or SMILES to a CAS number via PubChem."""
        try:
            # if query is smiles
            smiles = None
            if is_smiles(query):
                smiles = query
            try:
                cas = query2cas(query, self.url_cid, self.url_data)
            except ValueError as e:
                return str(e)
            if smiles is None:
                # NOTE(review): url=None is passed to pubchem_query2smiles and
                # the resolved SMILES is unused, so this round-trip only
                # validates the CAS — confirm the helper accepts a None url.
                try:
                    smiles = pubchem_query2smiles(cas, None)
                except ValueError as e:
                    return str(e)

            return cas
        except ValueError:
            return "CAS number not found"

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+ raise NotImplementedError()
56
+
57
+
58
class Query2SMILES(BaseTool):
    name:str = "CAS2SMILES"
    description :str = "Input a CAS number, returns SMILES."
    url: str = None
    chemspace_api_key: str = None

    def __init__(self, chemspace_api_key: str = None):
        super().__init__()
        self.chemspace_api_key = chemspace_api_key
        self.url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{}/{}"

    def _run(self, query: str) -> str:
        """Resolve a molecule name / CAS number to a SMILES string.

        Tries PubChem first; on failure falls back to ChemSpace when an API
        key was supplied, otherwise returns the PubChem error message.
        """
        if is_smiles(query) and is_multiple_smiles(query):
            return "Multiple SMILES strings detected, input one molecule at a time."
        try:
            smi = pubchem_query2smiles(query, self.url)
        except Exception as e:
            if not self.chemspace_api_key:
                # BUG FIX: the original referenced an undefined `chemspace`
                # variable here, which always raised a NameError that was then
                # swallowed; without an API key the only sensible behavior is
                # to report the PubChem error (which is what effectively
                # happened before).
                return str(e)
            try:
                chemspace = ChemSpace(self.chemspace_api_key)
                smi = chemspace.convert_mol_rep(query, "smiles")
                smi = smi.split(":")[1]
            except Exception:
                return str(e)

        return smi

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
98
+
99
class Mol2SMILES(BaseTool):
    name:str = "Mol2SMILES"
    description :str = "Input a molecular name , returns SMILES."

    def __init__(self, chemspace_api_key: str = None):
        super().__init__()

    def _run(self, query: str) -> str:
        """Resolve a molecule name to SMILES via PubChem, falling back to a local CSV.

        Only query with one specific name at a time.
        """
        if is_smiles(query) and is_multiple_smiles(query):
            return "Multiple SMILES strings detected, input one molecule at a time."
        try:
            smi = pubchem_query2smiles(query )
            return smi
        except Exception as e:
            # Fallback: local Name -> SMILES lookup table.
            try:
                csv_data = pd.read_csv('tool/dataset.csv',encoding='ISO-8859-1')
                relevant_rows = csv_data[csv_data['Name']==(query)]
                if not relevant_rows.empty:
                    # Get the most relevant answer (assuming we return the first match)
                    return relevant_rows.iloc[0]['SMILES']
            except Exception:
                # BUG FIX: was a bare `except:`; narrowed so KeyboardInterrupt
                # and SystemExit are not swallowed.
                pass
            # BUG FIX: the original fell through and implicitly returned None
            # when the CSV had no matching row; surface the PubChem error
            # message instead, matching the declared `-> str` contract.
            return str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
127
+
128
class SMILES2Name(BaseTool):
    name:str = "SMILES2Name"
    description:str = "Input SMILES, returns molecule name."



    def __init__(self):
        super().__init__()

    def _run(self, query: str) -> str:
        """Convert a SMILES string (or a name resolvable to SMILES) to a molecule name."""
        try:
            if not is_smiles(query):
                # The input may be a name/CAS; try to resolve it to SMILES first.
                try:
                    query2smiles = Query2SMILES()
                    query = query2smiles.run(query)
                except Exception:
                    # BUG FIX: was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.
                    raise ValueError("Invalid molecule input, no Pubchem entry")
            name = smiles2name(query)

            return name
        except Exception as e:
            return "Error: " + str(e)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
154
+ raise NotImplementedError()
tool/csv_search.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Dec 23 16:18:29 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from langchain.tools import BaseTool
9
+ import pandas as pd
10
+
11
class search_csv(BaseTool):
    # BUG FIX: BaseTool is a pydantic model, so class fields must carry type
    # annotations; the original left `name`/`description` unannotated and
    # annotated `llm` with BaseLanguageModel, which is not imported in this
    # file and raised a NameError at import time. The unused `llm` field is
    # removed; the API-key fields are kept (annotated) for compatibility.
    name: str = "csvsearch"
    description: str = (
        "input name, return the SMILES of materials "
        "convert name to SMILES."
    )
    openai_api_key: str = None
    semantic_scholar_api_key: str = None

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        """Look up a material name in dataset.csv and return its SMILES (or None)."""
        csv_data = pd.read_csv('dataset.csv',encoding='ISO-8859-1')
        # BUG FIX: the original filtered on an undefined `query` variable;
        # the tool's input parameter is `smiles` (the material name).
        relevant_rows = csv_data[csv_data['Name']==(smiles)]

        if not relevant_rows.empty:
            # Get the most relevant answer (assuming we return the first match)
            return relevant_rows.iloc[0]['SMILES']
        return None

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
34
+
tool/dap/.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
tool/dap/OSC/test.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1081233b2f0b3c77752a98b3c9e4ae065cb21aae4e3e5d31f8d673a1c2069ded
3
+ size 81596523
tool/dap/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # database
tool/dap/config/config_hparam.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ { "name": "biomarker_log",
2
+
3
+ "d_model_name" : "DeepChem/ChemBERTa-10M-MTR",
4
+ "p_model_name" : "DeepChem/ChemBERTa-77M-MLM",
5
+ "gpu_ids" : "0",
6
+ "model_mode" : "train",
7
+ "load_checkpoint" : "./checkpoint/bindingDB/test.ckpt",
8
+
9
+ "prot_maxlength" : 360,
10
+ "layer_limit" : true,
11
+
12
+ "max_epoch": 16,
13
+ "batch_size": 40,
14
+ "num_workers": 0,
15
+
16
+ "task_name" : "OSC",
17
+ "lr": 1e-4,
18
+ "layer_features" : [512, 128, 64, 1],
19
+ "dropout" : 0.1,
20
+ "loss_fn" : "MSE",
21
+
22
+ "traindata_rate" : 1.0,
23
+ "pretrained": {"chem":true, "prot":true},
24
+ "num_seed" : 111
25
+ }
26
+
tool/dap/config/predict.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ { "name": "biomarker_log",
2
+
3
+ "d_model_name" : "DeepChem/ChemBERTa-10M-MLM",
4
+ "p_model_name" : "DeepChem/ChemBERTa-10M-MTR",
5
+ "gpu_ids" : "0",
6
+ "model_mode" : "test",
7
+ "load_checkpoint" : "tool/dap/OSC/test.ckpt",
8
+
9
+ "prot_maxlength" : 360,
10
+ "layer_limit" : true,
11
+
12
+ "max_epoch": 16,
13
+ "batch_size": 40,
14
+ "num_workers": 0,
15
+
16
+ "task_name" : "OSC",
17
+ "lr": 1e-4,
18
+ "layer_features" : [128, 128, 128, 1],
19
+ "dropout" : 0.1,
20
+ "loss_fn" : "MSE",
21
+
22
+ "traindata_rate" : 1.0,
23
+ "pretrained": {"chem":true, "prot":true},
24
+ "num_seed" : 111
25
+ }
26
+
tool/dap/requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ altair
2
+ streamlit
3
+ streamlit-ketcher
4
+ torch
5
+ tqdm
6
+ transformers
7
+ pytorch_lightning
8
+ scipy
9
+ pandas
10
+ rdkit
11
+ scikit-learn
12
+ matplotlib
13
+ easydict
14
+ wandb
15
+ networkx
16
+ seaborn
17
+
18
+
tool/dap/run.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from transformers import AutoTokenizer
7
+
8
+ from .util.utils import *
9
+ from rdkit import Chem
10
+ from tqdm import tqdm
11
+ from .train import markerModel
12
+
13
# Pin CUDA device enumeration to PCI bus order and expose GPU 0 only.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# BUG FIX: the value previously carried a trailing space ('0 '), which some
# CUDA builds reject when parsing the visible-device list.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

device_count = torch.cuda.device_count()
device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")

device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

#-- biomarker model: hyper-parameters come from the JSON config file.
config = load_hparams('tool/dap/config/predict.json')
config = DictX(config)

# Load weights directly from the checkpoint. Constructing a fresh markerModel
# first (as before) only downloaded pretrained encoders that were immediately
# discarded when this assignment replaced the instance.
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
model.eval()
model.freeze()

if device_biomarker.type == 'cuda':
    model = torch.nn.DataParallel(model)
39
+
40
def get_marker(drug_inputs, prot_inputs):
    """Run the globally-loaded marker model on a tokenized pair and return the
    squeezed predictions as plain Python values."""
    raw_preds = model(drug_inputs, prot_inputs)
    return torch.squeeze(raw_preds).tolist()
50
+
51
+
52
def marker_prediction(smiles, aas):
    """Tokenize acceptor SMILES (`smiles`) and donor SMILES (`aas`), run the
    marker model, and return the squeezed prediction(s).

    Returns a float (single pair) or a list of floats; on failure an
    {'Error_message': str} dict is returned instead.
    """
    try:
        # Donor SMILES are space-separated per character before tokenization.
        aas_input = [' '.join(list(entry)) for entry in aas]

        a_tok = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        a_inputs = {'input_ids': a_tok['input_ids'].to(device),
                    'attention_mask': a_tok['attention_mask'].to(device)}

        d_tok = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        d_inputs = {'input_ids': d_tok['input_ids'].to(device),
                    'attention_mask': d_tok['attention_mask'].to(device)}

        return get_marker(a_inputs, d_inputs)

    except Exception as e:
        print(e)
        # BUG FIX: stringify the exception so the error payload is serializable
        # (the exception object itself is not JSON/CSV friendly).
        return {'Error_message': str(e)}
78
+
79
+
80
def smiles_aas_test(smile_acc, smile_don):
    """Predict the marker value for a single acceptor/donor SMILES pair.

    Both SMILES are canonicalized with RDKit before prediction. Returns the
    scalar prediction from marker_prediction (or its error dict).
    Raises ValueError for SMILES that RDKit cannot parse.
    """
    mola = Chem.MolFromSmiles(smile_acc)
    if mola is None:
        # BUG FIX guard: MolFromSmiles returns None on invalid input, which
        # previously surfaced as an opaque AttributeError from MolToSmiles.
        raise ValueError(f"Invalid acceptor SMILES: {smile_acc!r}")
    mold = Chem.MolFromSmiles(smile_don)
    if mold is None:
        raise ValueError(f"Invalid donor SMILES: {smile_don!r}")

    smile_acc = Chem.MolToSmiles(mola, canonical=True)
    smile_don = Chem.MolToSmiles(mold, canonical=True)

    # The original batching machinery (batch_size=1 over exactly one pair)
    # always produced a single batch, so the loop collapses to one call.
    return marker_prediction([smile_acc], [smile_don])
118
+
119
+
120
+
121
# Smoke test: predict for one hard-coded acceptor/donor SMILES pair.
if __name__ == '__main__':

    a = smiles_aas_test('CC(C)CCCC(C)CCC1=C(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)SC2=C1N(CCC(C)CCCC(C)C)C1=C2C2=NSN=C2C2=C1N(CCC(C)CCCC(C)C)C1=C2SC(/C=C2\C(=O)C3=C(C=C(F)C(F)=C3)C2=C(C#N)C#N)=C1CCC(C)CCCC(C)C','CCCCC(CC)CC1=C(F)C=C(C2=C3C=C(C4=CC=C(C5=C6C(=O)C7=C(CC(CC)CCCC)SC(CC(CC)CCCC)=C7C(=O)C6=C(C6=CC=C(C)S6)S5)S4)SC3=C(C3=CC(F)=C(CC(CC)CCCC)S3)C3=C2SC(C)=C3)S1')
124
+
tool/dap/screen.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from transformers import AutoTokenizer
7
+ from rdkit import Chem
8
+ from .util.utils import *
9
+
10
+ from tqdm import tqdm
11
+ from .train import markerModel
12
+
13
# Pin CUDA device enumeration to PCI bus order and expose GPU 0 only.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# BUG FIX: the value previously carried a trailing space ('0 '), which some
# CUDA builds reject when parsing the visible-device list.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

device_count = torch.cuda.device_count()
device_biberta = torch.device('cuda' if torch.cuda.is_available() else "cpu")

device = torch.device('cpu')
a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
d_model_name = 'DeepChem/ChemBERTa-10M-MTR'

tokenizer = AutoTokenizer.from_pretrained(a_model_name)
d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)

#-- biberta model: hyper-parameters come from the JSON config file.
config = load_hparams('tool/dap/config/predict.json')
config = DictX(config)

# Load weights straight from the checkpoint. Building a fresh markerModel
# first (as before) only downloaded pretrained encoders that were immediately
# discarded when this assignment replaced the instance.
model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
model.eval()
model.freeze()

if device_biberta.type == 'cuda':
    model = torch.nn.DataParallel(model)
+ model = torch.nn.DataParallel(model)
40
+
41
def get_biberta(drug_inputs, prot_inputs):
    """Forward a tokenized acceptor/donor pair through the loaded model and
    return the squeezed predictions as plain Python values."""
    raw_output = model(drug_inputs, prot_inputs)
    return torch.squeeze(raw_output).tolist()
51
+
52
+
53
def biberta_prediction(smiles, aas):
    """Tokenize acceptor SMILES (`smiles`) and donor SMILES (`aas`), run the
    model, and return one {'acceptor', 'donor', 'predict'} dict per pair.

    On failure an {'Error_message': str} dict is returned instead.
    """
    try:
        # Donor SMILES are space-separated per character before tokenization.
        aas_input = [' '.join(list(entry)) for entry in aas]

        a_tok = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        a_inputs = {'input_ids': a_tok['input_ids'].to(device),
                    'attention_mask': a_tok['attention_mask'].to(device)}

        d_tok = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
        d_inputs = {'input_ids': d_tok['input_ids'].to(device),
                    'attention_mask': d_tok['attention_mask'].to(device)}

        output_predict = get_biberta(a_inputs, d_inputs)

        return [{'acceptor': smiles[i], 'donor': aas[i], 'predict': output_predict[i]}
                for i in range(len(aas))]

    except Exception as e:
        print(e)
        # BUG FIX: stringify the exception so the error payload is serializable.
        return {'Error_message': str(e)}
80
+
81
+
82
def smiles_aas_test(file):
    """Batch-predict for a CSV file with 'donor' and 'acceptor' SMILES columns.

    Every SMILES is canonicalized with RDKit before prediction. Returns a list
    of {'acceptor', 'donor', 'predict'} dicts, or an {'Error_message': str}
    dict on failure.
    """
    try:
        smiles_aas = pd.read_csv(file)

        donor, acceptor = [], []
        for smi in smiles_aas['donor']:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # Guard: MolFromSmiles returns None for unparsable SMILES.
                raise ValueError(f"Invalid donor SMILES: {smi!r}")
            donor.append(Chem.MolToSmiles(mol))
        for smi in smiles_aas['acceptor']:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                raise ValueError(f"Invalid acceptor SMILES: {smi!r}")
            acceptor.append(Chem.MolToSmiles(mol))

        # Single prediction call; the unused bookkeeping from the original
        # (batch_se, biberta_list, biberta_datas, the dead DataFrame) is gone.
        return biberta_prediction(acceptor, donor)

    except Exception as e:
        print(e)
        # BUG FIX: stringify the exception so the error payload is serializable.
        return {'Error_message': str(e)}
118
+
tool/dap/train.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3
+
4
+ import gc, os
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from scipy.stats import pearsonr
9
+ from .util.utils import *
10
+ #from .util.attention_flow import *
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ import sklearn as sk
16
+ from torch.utils.data import Dataset, DataLoader
17
+
18
+ import pytorch_lightning as pl
19
+ from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
20
+ from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
21
+ from transformers import AutoConfig, AutoTokenizer, RobertaModel, BertModel
22
+ from sklearn.metrics import r2_score, mean_absolute_error,mean_squared_error
23
+
24
class markerDataset(Dataset):
    """Torch Dataset pairing acceptor/donor SMILES with a float regression label.

    Each item is tokenized on the fly with the two HuggingFace tokenizers and
    padded/truncated to a fixed length of 400 tokens.
    """

    def __init__(self, list_IDs, labels, df_dti, d_tokenizer, p_tokenizer):
        """Store the index list, label array, source dataframe and tokenizers."""
        self.labels = labels          # array-like of float labels, indexed by dataframe row
        self.list_IDs = list_IDs      # row indices of df_dti served by this dataset
        self.df = df_dti              # dataframe with 'acceptor' and 'donor' SMILES columns

        self.d_tokenizer = d_tokenizer   # tokenizer for acceptor SMILES
        self.p_tokenizer = p_tokenizer   # tokenizer for donor SMILES



    def convert_data(self, acc_data, don_data):
        # NOTE(review): unused helper; it also tokenizes don_data with
        # d_tokenizer rather than p_tokenizer — confirm before reusing.

        d_inputs = self.d_tokenizer(acc_data, return_tensors="pt")
        p_inputs = self.d_tokenizer(don_data, return_tensors="pt")

        acc_input_ids = d_inputs['input_ids']
        acc_attention_mask = d_inputs['attention_mask']
        acc_inputs = {'input_ids': acc_input_ids, 'attention_mask': acc_attention_mask}

        don_input_ids = p_inputs['input_ids']
        don_attention_mask = p_inputs['attention_mask']
        don_inputs = {'input_ids': don_input_ids, 'attention_mask': don_attention_mask}

        return acc_inputs, don_inputs

    def tokenize_data(self, acc_data, don_data):
        # Unused helper: token lists wrapped in BERT-style [CLS]/[SEP] markers.
        tokenize_acc = ['[CLS]'] + self.d_tokenizer.tokenize(acc_data) + ['[SEP]']

        tokenize_don = ['[CLS]'] + self.p_tokenizer.tokenize(don_data) + ['[SEP]']

        return tokenize_acc, tokenize_don

    def __len__(self):
        """Denotes the total number of samples."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Generate one sample: token ids/masks for both molecules plus the label."""
        index = self.list_IDs[index]
        acc_data = self.df.iloc[index]['acceptor']
        don_data = self.df.iloc[index]['donor']

        d_inputs = self.d_tokenizer(acc_data, padding='max_length', max_length=400, truncation=True, return_tensors="pt")
        p_inputs = self.p_tokenizer(don_data, padding='max_length', max_length=400, truncation=True, return_tensors="pt")

        # squeeze() drops the batch dimension the tokenizer adds, so the
        # DataLoader can re-batch the samples itself.
        d_input_ids = d_inputs['input_ids'].squeeze()
        d_attention_mask = d_inputs['attention_mask'].squeeze()
        p_input_ids = p_inputs['input_ids'].squeeze()
        p_attention_mask = p_inputs['attention_mask'].squeeze()

        labels = torch.as_tensor(self.labels[index], dtype=torch.float)

        dataset = [d_input_ids, d_attention_mask, p_input_ids, p_attention_mask, labels]
        return dataset
82
+
83
+
84
class markerDataModule(pl.LightningDataModule):
    """LightningDataModule serving acceptor/donor SMILES pairs from CSV splits.

    Builds one tokenizer per molecule role and wraps the train/val(/test)
    dataframes in markerDataset instances.
    """

    def __init__(self, task_name, acc_model_name, don_model_name, num_workers, batch_size, traindata_rate = 1.0):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.task_name = task_name

        # Fraction of the train/val splits actually used (e.g. for ablations).
        self.traindata_rate = traindata_rate

        self.d_tokenizer = AutoTokenizer.from_pretrained(acc_model_name)
        self.p_tokenizer = AutoTokenizer.from_pretrained(don_model_name)

        self.df_train = None
        self.df_val = None
        self.df_test = None

        self.load_testData = True

        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None

    def get_task(self, task_name):
        """Map a task name to its dataset folder (case-insensitively)."""
        # BUG FIX: task_name.lower() was compared against the uppercase
        # literal 'OSC', so that branch could never match and the method
        # silently returned None for the OSC task.
        if task_name.lower() == 'osc':
            return './dataset/OSC/'

        elif task_name.lower() == 'merge':
            self.load_testData = False
            return './dataset/MergeDataset'

    def prepare_data(self):
        # Use this method to do things that might write to disk or that need
        # to be done only from a single process in distributed settings.
        # NOTE(review): the folder is hard-coded rather than derived from
        # get_task(self.task_name) — confirm that is intentional.
        dataFolder = './dataset/OSC'

        self.df_train = pd.read_csv(dataFolder + '/train.csv')
        self.df_val = pd.read_csv(dataFolder + '/val.csv')

        ## -- Apply the data length rate to both splits -- ##
        traindata_length = int(len(self.df_train) * self.traindata_rate)
        validdata_length = int(len(self.df_val) * self.traindata_rate)

        self.df_train = self.df_train[:traindata_length]
        self.df_val = self.df_val[:validdata_length]

        if self.load_testData is True:
            self.df_test = pd.read_csv(dataFolder + '/test.csv')

    def setup(self, stage=None):
        """Materialize the markerDataset objects for the requested stage."""
        if stage == 'fit' or stage is None:
            self.train_dataset = markerDataset(self.df_train.index.values, self.df_train.Label.values, self.df_train,
                                               self.d_tokenizer, self.p_tokenizer)
            self.valid_dataset = markerDataset(self.df_val.index.values, self.df_val.Label.values, self.df_val,
                                               self.d_tokenizer, self.p_tokenizer)

        if self.load_testData is True:
            self.test_dataset = markerDataset(self.df_test.index.values, self.df_test.Label.values, self.df_test,
                                              self.d_tokenizer, self.p_tokenizer)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
151
+
152
+
153
class markerModel(pl.LightningModule):
    """Dual-RoBERTa regression model for acceptor/donor SMILES pairs.

    The [CLS] embeddings from both encoders are concatenated and passed
    through an MLP decoder producing one regression output per pair.
    """

    def __init__(self, acc_model_name, don_model_name, lr, dropout, layer_features, loss_fn = "smooth", layer_limit = True, d_pretrained=True, p_pretrained=True):
        super().__init__()
        self.lr = lr
        self.loss_fn = loss_fn                         # 'MSE' -> MSELoss, otherwise SmoothL1Loss
        self.criterion = torch.nn.MSELoss()
        self.criterion_smooth = torch.nn.SmoothL1Loss()

        #-- Acceptor encoder: pretrained weights or a fresh RoBERTa config.
        acc_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")
        if d_pretrained is False:
            self.d_model = RobertaModel(acc_config)
            print('acceptor model without pretraining')
        else:
            self.d_model = RobertaModel.from_pretrained(acc_model_name, num_labels=2,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        #-- Donor encoder: same scheme.
        don_config = AutoConfig.from_pretrained("seyonec/SMILES_BPE_PubChem_100k_shard00")
        if p_pretrained is False:
            self.p_model = RobertaModel(don_config)
            print('donor model without pretraining')
        else:
            self.p_model = RobertaModel.from_pretrained(don_model_name,
                                                        output_hidden_states=True,
                                                        output_attentions=True)

        #-- Decoder head: Linear -> ReLU [-> Dropout] per hidden width, then a
        #   final Linear to layer_features[-1] outputs. (The original if/else
        #   appended ReLU on both branches, so it is collapsed here.)
        layers = []
        firstfeature = self.d_model.config.hidden_size + self.p_model.config.hidden_size
        for feature_idx in range(0, len(layer_features) - 1):
            layers.append(nn.Linear(firstfeature, layer_features[feature_idx]))
            firstfeature = layer_features[feature_idx]
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(firstfeature, layer_features[-1]))

        self.decoder = nn.Sequential(*layers)

        self.save_hyperparameters()

    def forward(self, acc_inputs, don_inputs):
        """Encode both molecules and decode the concatenated [CLS] vectors."""
        d_outputs = self.d_model(acc_inputs['input_ids'], acc_inputs['attention_mask'])
        p_outputs = self.p_model(don_inputs['input_ids'], don_inputs['attention_mask'])

        outs = torch.cat((d_outputs.last_hidden_state[:, 0], p_outputs.last_hidden_state[:, 0]), dim=1)
        return self.decoder(outs)

    def attention_output(self, acc_inputs, don_inputs):
        """Like forward, but also return both encoders' attention maps."""
        d_outputs = self.d_model(acc_inputs['input_ids'], acc_inputs['attention_mask'])
        p_outputs = self.p_model(don_inputs['input_ids'], don_inputs['attention_mask'])

        outs = torch.cat((d_outputs.last_hidden_state[:, 0], p_outputs.last_hidden_state[:, 0]), dim=1)
        outs = self.decoder(outs)

        return d_outputs['attentions'], p_outputs['attentions'], outs

    def _compute_loss(self, logits, labels):
        """Select MSE or SmoothL1 according to the configured loss_fn."""
        if self.loss_fn == 'MSE':
            return self.criterion(logits, labels)
        return self.criterion_smooth(logits, labels)

    def training_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        logits = self(acc_inputs, don_inputs).squeeze(dim=1)
        loss = self._compute_loss(logits, labels)

        self.log("train_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        logits = self(acc_inputs, don_inputs).squeeze(dim=1)
        loss = self._compute_loss(logits, labels)

        self.log("valid_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def validation_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def validation_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # BUG FIX: labels were cast to torch.int, truncating the float
        # regression targets and corrupting the mae/mse/r2 metrics.
        labels = torch.as_tensor(torch.cat([output['labels'] for output in outputs], dim=0), dtype=torch.float)

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        acc_inputs = {'input_ids': batch[0], 'attention_mask': batch[1]}
        don_inputs = {'input_ids': batch[2], 'attention_mask': batch[3]}
        labels = batch[4]

        logits = self(acc_inputs, don_inputs).squeeze(dim=1)
        loss = self._compute_loss(logits, labels)

        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True)
        return {"logits": logits, "labels": labels}

    def test_step_end(self, outputs):
        return {"logits": outputs['logits'], "labels": outputs['labels']}

    def test_epoch_end(self, outputs):
        preds = self.convert_outputs_to_preds(outputs)
        # BUG FIX: same torch.int truncation as validation_epoch_end.
        labels = torch.as_tensor(torch.cat([output['labels'] for output in outputs], dim=0), dtype=torch.float)

        mae, mse, r2, r = self.log_score(preds, labels)

        self.log("mae", mae, on_step=False, on_epoch=True, logger=True)
        self.log("mse", mse, on_step=False, on_epoch=True, logger=True)
        self.log("r2", r2, on_step=False, on_epoch=True, logger=True)
        self.log("r", r, on_step=False, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """AdamW with weight decay disabled for bias/gamma/beta parameters."""
        param_optimizer = list(self.named_parameters())

        no_decay = ["bias", "gamma", "beta"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay_rate": 0.0001
            },
            {
                "params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                "weight_decay_rate": 0.0
            },
        ]
        return torch.optim.AdamW(optimizer_grouped_parameters, lr=self.lr)

    def convert_outputs_to_preds(self, outputs):
        """Concatenate the per-step logits into one prediction tensor."""
        return torch.cat([output['logits'] for output in outputs], dim=0)

    def log_score(self, preds, labels):
        """Compute and print mae/mse/r2 and the Pearson correlation."""
        y_pred = preds.detach().cpu().numpy()
        y_label = labels.detach().cpu().numpy()

        mae = mean_absolute_error(y_label, y_pred)
        mse = mean_squared_error(y_label, y_pred)
        r2 = r2_score(y_label, y_pred)
        # BUG FIX: pearsonr returns (statistic, p-value); keep only the
        # statistic so self.log("r", r) receives a scalar, not a tuple.
        r = pearsonr(y_label, y_pred)[0]
        print(f'\nmae : {mae}')
        print(f'mse : {mse}')
        print(f'r2 : {r2}')
        print(f'r : {r}')

        return mae, mse, r2, r
343
+
344
+
345
def main_wandb(config=None):
    """Train/evaluate with Weights & Biases experiment tracking.

    Any exception is printed rather than propagated (best-effort runner).
    """
    try:
        import wandb  # BUG FIX: wandb was referenced but never imported

        if config is not None:
            wandb.init(config=config, project=project_name)
        else:
            wandb.init(settings=wandb.Settings(console='off'))

        config = wandb.config
        pl.seed_everything(seed=config.num_seed)

        # BUG FIX: markerDataModule takes (task_name, d_model, p_model,
        # num_workers, batch_size, traindata_rate); the extra prot_maxlength
        # argument raised a TypeError that the broad except silently printed.
        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)
        dm.prepare_data()
        dm.setup()

        model_type = str(config.pretrained['chem'])+"To"+str(config.pretrained['prot'])
        #model_logger = WandbLogger(project=project_name)
        # BUG FIX: mae is an error metric — keep the checkpoint with the
        # *lowest* value, so mode must be "min", not "max".
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}", save_top_k=1, monitor="mae", mode="min")

        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision=16,
            #logger=model_logger,
            callbacks=[checkpoint_callback],
            accelerator='cpu', log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])
            model.train()
            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)

    except Exception as e:
        print(e)
390
+
391
+
392
def main_default(config):
    """Train/evaluate locally (no wandb) from a plain config dict.

    Any exception is printed rather than propagated (best-effort runner).
    """
    try:
        config = DictX(config)
        pl.seed_everything(seed=config.num_seed)

        dm = markerDataModule(config.task_name, config.d_model_name, config.p_model_name,
                              config.num_workers, config.batch_size, config.traindata_rate)

        dm.prepare_data()
        dm.setup()
        model_type = str(config.pretrained['chem'])+"To"+str(config.pretrained['prot'])
        # model_logger = TensorBoardLogger("./log", name=f"{config.task_name}_{model_type}_{config.num_seed}")
        # BUG FIX: mse is an error metric — keep the checkpoint with the
        # *lowest* value, so mode must be "min", not "max".
        checkpoint_callback = ModelCheckpoint(f"{config.task_name}_{model_type}_{config.lr}_{config.num_seed}", save_top_k=1, monitor="mse", mode="min")

        trainer = pl.Trainer(
            max_epochs=config.max_epoch,
            precision=32,
            # logger=model_logger,
            callbacks=[checkpoint_callback],
            accelerator='cpu', log_every_n_steps=40
        )

        if config.model_mode == "train":
            model = markerModel(config.d_model_name, config.p_model_name,
                                config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])

            model.train()

            trainer.fit(model, datamodule=dm)

            model.eval()
            trainer.test(model, datamodule=dm)

        else:
            model = markerModel.load_from_checkpoint(config.load_checkpoint)

            model.eval()
            trainer.test(model, datamodule=dm)
    except Exception as e:
        print(e)
433
+
434
+
435
if __name__ == '__main__':
    # Toggle between a wandb-tracked run and a plain local run.
    using_wandb = False

    if using_wandb == True:
        #-- hyper param config file Load --##
        config = load_hparams('config/config_hparam.json')
        project_name = config["name"]

        main_wandb(config)

        ##-- wandb Sweep Hyper Param Tuning --##
        # config = load_hparams('config/config_sweep_bindingDB.json')
        # project_name = config["name"]
        # sweep_id = wandb.sweep(config, project=project_name)
        # wandb.agent(sweep_id, main_wandb)

    else:
        config = load_hparams('config/config_hparam.json')

        main_default(config)
tool/dap/util/attention_flow.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import networkx as nx
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ import matplotlib.pyplot as plt
6
+
7
+ import seaborn as sns
8
+ import itertools
9
+ import matplotlib as mpl
10
+ # import cugraph as cnx
11
+
12
# Global matplotlib styling for all attention plots: base font sizes and a
# thin axis line applied process-wide.
rc={'font.size': 10, 'axes.labelsize': 10, 'legend.fontsize': 10.0,
    'axes.titlesize': 32, 'xtick.labelsize': 20, 'ytick.labelsize': 16}
plt.rcParams.update(**rc)
mpl.rcParams['axes.linewidth'] = .5 #set the value globally
16
+
17
+
18
def plot_attention_heatmap(att, s_position, t_positions, input_tokens):
    """Heatmap of attention from s_position to t_positions across layers
    (layers on the y-axis, deepest layer at the top)."""
    cls_att = np.flip(att[:, s_position, t_positions], axis=0)
    xticklb = [tok for i, tok in enumerate(input_tokens) if i in t_positions]
    yticklb = ['' if i % 2 else str(i) for i in np.arange(att.shape[0], 0, -1)]
    return sns.heatmap(cls_att, xticklabels=xticklb, yticklabels=yticklb, cmap="YlOrRd")
26
+
27
def convert_adjmat_tomats(adjmat, n_layers, l):
    """Slice the stacked layer-graph adjacency matrix back into per-layer
    (l, l) blocks, returned as a float array of shape (n_layers, l, l)."""
    out = np.zeros((n_layers, l, l))
    for layer in range(n_layers):
        row_lo, row_hi = (layer + 1) * l, (layer + 2) * l
        col_lo, col_hi = layer * l, (layer + 1) * l
        out[layer] = adjmat[row_lo:row_hi, col_lo:col_hi]
    return out
34
+
35
def make_residual_attention(attentions):
    """Average attention heads, add an identity residual, and row-normalize.

    Takes a sequence of per-layer attention tensors shaped
    (batch, heads, seq, seq); only batch element 0 is used. Returns
    (raw_per_head, residual_normalized) numpy arrays.
    """
    stacked = np.asarray([layer.detach().cpu().numpy() for layer in attentions])
    attentions_mat = stacked[:, 0]                   # keep batch element 0 only
    head_mean = attentions_mat.mean(axis=1)          # == sum over heads / n_heads
    with_residual = head_mean + np.eye(head_mean.shape[1])[None, ...]
    res_att_mat = with_residual / with_residual.sum(axis=-1)[..., None]
    return attentions_mat, res_att_mat
44
+
45
+ ## -------------------------------------------------------- ##
46
+ ## -- Make flow network (No Print Node - edge Connection)-- ##
47
+ ## -------------------------------------------------------- ##
48
+
49
def make_flow_network(mat, input_tokens):
    """Build a directed layer graph whose edge capacities are attention
    weights, plus a label->node-index map, for networkx max-flow routines.
    """
    n_layers, length, _ = mat.shape
    adj_mat = np.zeros(((n_layers + 1) * length, (n_layers + 1) * length))
    labels_to_index = {}
    for k in range(length):
        labels_to_index[str(k) + "_" + input_tokens[k]] = k

    # Node (i, k_f) connects back to every node of the previous layer with
    # the corresponding attention weight.
    for i in range(1, n_layers + 1):
        for k_f in range(length):
            index_from = i * length + k_f
            labels_to_index["L" + str(i) + "_" + str(k_f)] = index_from
            for k_t in range(length):
                adj_mat[index_from][(i - 1) * length + k_t] = mat[i - 1][k_f][k_t]

    # BUG FIX: nx.from_numpy_matrix was removed in networkx 3.0;
    # from_numpy_array is the drop-in replacement (available since 2.x).
    net_graph = nx.from_numpy_array(adj_mat, create_using=nx.DiGraph())
    for i in range(adj_mat.shape[0]):
        for j in range(adj_mat.shape[1]):
            nx.set_edge_attributes(net_graph, {(i, j): adj_mat[i, j]}, 'capacity')

    return net_graph, labels_to_index
71
+
72
+
73
def make_input_node(attention_mat, res_labels_to_index):
    """Return the labels belonging to the input (layer-0) token positions,
    i.e. those whose node index is below the sequence length."""
    seq_len = attention_mat.shape[-1]
    return [label for label, idx in res_labels_to_index.items() if idx < seq_len]
80
+ ## ------------------------------------------------ ##
81
+ ## -- Draw Attention flow node - Edge Connection -- ##
82
+ ## ------------------------------------------------ ##
83
+
84
+ ##-- networkx graph Initation and Calculation flow --##
85
def get_adjmat(mat, input_tokens):
    """Flatten a (n_layers, length, length) attention stack into one big
    adjacency matrix over (n_layers+1)*length nodes, plus a label->index map.

    Layer-0 nodes are labeled "<pos>_<token>"; deeper nodes "L<layer>_<pos>".
    """
    n_layers, length, _ = mat.shape
    total_nodes = (n_layers + 1) * length
    adj_mat = np.zeros((total_nodes, total_nodes))

    labels_to_index = {f"{k}_{input_tokens[k]}": k for k in range(length)}

    for layer in range(1, n_layers + 1):
        for k_f in range(length):
            index_from = layer * length + k_f
            labels_to_index[f"L{layer}_{k_f}"] = index_from
            # One vectorized row assignment replaces the inner k_t loop.
            adj_mat[index_from, (layer - 1) * length:layer * length] = mat[layer - 1][k_f]

    return adj_mat, labels_to_index
102
+
103
def draw_attention_graph(adjmat, labels_to_index, n_layers, length):
    """Draw the layered attention graph with networkx/matplotlib and return it.

    Nodes are laid out in columns (one per layer); edge width encodes the
    attention weight. NOTE(review): nx.from_numpy_matrix was removed in
    networkx 3.0 — this function requires networkx < 3; confirm the pinned
    version before upgrading.
    """
    A = adjmat
    net_graph=nx.from_numpy_matrix(A, create_using=nx.DiGraph())
    for i in np.arange(A.shape[0]):
        for j in np.arange(A.shape[1]):
            nx.set_edge_attributes(net_graph, {(i,j): A[i,j]}, 'capacity')

    # Column layout: x grows with the layer index, y runs top-down over tokens.
    pos = {}
    label_pos = {}
    for i in np.arange(n_layers+1):
        for k_f in np.arange(length):
            pos[i*length+k_f] = ((i+0.4)*2, length - k_f)
            label_pos[i*length+k_f] = (i*2, length - k_f)

    # Only layer-0 nodes keep a visible token label; deeper nodes are blank.
    index_to_labels = {}
    for key in labels_to_index:
        index_to_labels[labels_to_index[key]] = key.split("_")[-1]
        if labels_to_index[key] >= length:
            index_to_labels[labels_to_index[key]] = ''

    #plt.figure(1,figsize=(20,12))
    nx.draw_networkx_nodes(net_graph,pos,node_color='green', labels=index_to_labels, node_size=50)
    nx.draw_networkx_labels(net_graph,pos=label_pos, labels=index_to_labels, font_size=18)

    all_weights = []
    #4 a. Iterate through the graph nodes to gather all the weights
    for (node1,node2,data) in net_graph.edges(data=True):
        all_weights.append(data['weight']) #we'll use this when determining edge thickness

    #4 b. Get unique weights
    unique_weights = list(set(all_weights))

    #4 c. Plot the edges - one by one!
    for weight in unique_weights:
        #4 d. Form a filtered list with just the weight you want to draw
        weighted_edges = [(node1,node2) for (node1,node2,edge_attr) in net_graph.edges(data=True) if edge_attr['weight']==weight]
        #4 e. I think multiplying by [num_nodes/sum(all_weights)] makes the graphs edges look cleaner

        w = weight #(weight - min(all_weights))/(max(all_weights) - min(all_weights))
        width = w
        nx.draw_networkx_edges(net_graph,pos,edgelist=weighted_edges,width=width, edge_color='darkblue')

    return net_graph
146
+
147
def compute_flows(G, labels_to_index, input_nodes, length):
    """Max-flow from every non-input node down to each input node.

    Each row of the returned matrix is normalized to sum to 1; flows are
    recorded against the previous layer's column block.
    """
    n = len(labels_to_index)
    flow_values = np.zeros((n, n))
    for label in tqdm(labels_to_index, desc="flow algorithms", total=len(labels_to_index)):
        if label in input_nodes:
            continue
        u = labels_to_index[label]
        pre_layer = int(u / length) - 1
        for inp_label in input_nodes:
            v = labels_to_index[inp_label]
            flow_values[u][pre_layer * length + v] = nx.maximum_flow_value(
                G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
        flow_values[u] /= flow_values[u].sum()

    return flow_values
163
+
164
def compute_node_flow(G, labels_to_index, input_nodes, output_nodes, length):
    """Same as compute_flows, but restricted to the given output nodes."""
    n = len(labels_to_index)
    flow_values = np.zeros((n, n))
    for label in output_nodes:
        if label in input_nodes:
            continue
        u = labels_to_index[label]
        pre_layer = int(u / length) - 1
        for inp_label in input_nodes:
            v = labels_to_index[inp_label]
            flow_values[u][pre_layer * length + v] = nx.maximum_flow_value(
                G, u, v, flow_func=nx.algorithms.flow.edmonds_karp)
        flow_values[u] /= flow_values[u].sum()

    return flow_values
179
+
180
def compute_joint_attention(att_mat, add_residual=True):
    """Roll attention up through the layers: joint[i] = aug[i] @ joint[i-1].

    With add_residual, an identity matrix is added per layer and rows are
    re-normalized before the cumulative product.
    """
    if add_residual:
        identity = np.eye(att_mat.shape[1])[None, ...]
        augmented = att_mat + identity
        augmented = augmented / augmented.sum(axis=-1)[..., None]
    else:
        augmented = att_mat

    joint = np.zeros(augmented.shape)
    joint[0] = augmented[0]
    for layer in range(1, augmented.shape[0]):
        joint[layer] = augmented[layer] @ joint[layer - 1]

    return joint
tool/dap/util/attention_plot.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+
6
def make_attention_table(att, tokens, numb, token_idx = 0, layerNumb = -1):
    """Save and plot one query token's attention over the sequence as a bar chart.

    Writes the raw values to ``amino_acid_seq_attention_{numb}.csv`` and the
    rendered figure to ``figures/Amino_acid_seq_{numb}.png``.

    Args:
        att: attention indexed as [layer, query_token, key_token]; key token 0
            is skipped — presumably a special/CLS token, TODO confirm.
        tokens: token strings aligned with the key axis of ``att``.
        numb: identifier appended to the output file names.
        token_idx: query token whose attention row is plotted (default 0).
        layerNumb: layer to visualise (default -1 = last layer).
    """
    # Attention of the chosen query token over all keys except token 0.
    token_att = att[layerNumb, token_idx, range(1, len(tokens))]

    token_label=[]
    token_numb=[]
    for idx, token in enumerate(tokens[1:]) :
        token_label.append(f"<b>{token}</b>")  # bold HTML tick labels for plotly
        token_numb.append(f"{idx}")

    pair = list(zip(token_numb, token_att))

    df = pd.DataFrame(pair, columns=["Amino acid", "Attention rate"])
    df.to_csv(f"amino_acid_seq_attention_{numb}.csv", index=None)

    # Indices of the three highest-attention residues, highlighted in crimson.
    top3_idx = sorted(range(len(token_att)), key=lambda i: token_att[i], reverse=True)[:3]

    colors = ['cornflowerblue', ] * len(token_numb)

    for i in top3_idx:
        colors[i] = 'crimson'

    fig = go.Figure(data=[go.Bar(
        x=df["Amino acid"],
        y=df["Attention rate"],
        # range_y=[min(token_att), max(token_att)],
        marker_color=colors # marker color can be a single color value or an iterable
    )])

    # fig = px.histogram(df, x="Amino acid", y="Attention rate", range_y=[min(token_att), max(token_att)])

    # White background with subtle axis lines/grid.
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)',mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)',mirror=False)
    fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
                             'font':{'size':40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      # Replace numeric tick values with the bold token labels.
                      xaxis=dict(tickmode='array',
                                 tickvals=token_numb,
                                 ticktext=token_label
                                 ),

                      xaxis_title={'text': "Amino acid sequence",
                                   'font':{'size':30}},
                      yaxis_title={'text': "Attention rate",
                                   'font':{'size':30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    # NOTE(review): assumes a ./figures directory already exists.
    fig.write_image(f'figures/Amino_acid_seq_{numb}.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
62
+
63
+
64
def read_attention():
    """Re-plot a previously saved attention CSV as a bar chart.

    Reads ``../amino_acid_seq_attention.csv`` (as written by
    ``make_attention_table``-style runs) and writes the figure to
    ``figures/Amino_acid_seq.png``.
    """
    df = pd.read_csv("../amino_acid_seq_attention.csv")
    # d_flow_values = np.asarray(d_read_flow_values)

    fig = px.bar(df, x="Amino acid", y="Attention rate", range_y=[min(df["Attention rate"]), max(df["Attention rate"])])

    # White background with subtle axis lines/grid (same styling as make_attention_table).
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)',mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)',mirror=False)
    fig.update_layout(title={'text': "<b>Attention rate of amino acid sequence token</b>",
                             'font':{'size':40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      xaxis_title={'text': "Amino acid sequence",
                                   'font':{'size':30}},
                      yaxis_title={'text': "Attention rate",
                                   'font':{'size':30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    # NOTE(review): assumes a ./figures directory already exists.
    fig.write_image('figures/Amino_acid_seq.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()

if __name__ == '__main__':
    # Script entry point: regenerate the figure from the saved CSV.
    read_attention()
tool/dap/util/boxplot.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+
4
+ from scipy import stats
5
+ import plotly.express as px
6
+
7
+ from plotly.subplots import make_subplots
8
+ import plotly.graph_objects as go
9
+
10
+ ROC = 1
11
+ PR = 2
12
+
13
def add_p_value_annotation(fig, array_columns, subplot=None, _format=dict(interline=0.03, text_height=1.03, color='black')):
    ''' Adds notations giving the p-value between two box plot data (t-test two-sided comparison)

    Parameters:
    ----------
    fig: figure
        plotly boxplot figure
    array_columns: np.array
        array of which columns to compare
        e.g.: [[0,1], [1,2]] compares column 0 with 1 and 1 with 2
    subplot: None or int
        specifies if the figures has subplots and what subplot to add the notation to
    _format: dict
        format characteristics for the lines

    Returns:
    -------
    fig: figure
        figure with the added notation
    '''
    # Specify in what y_range to plot for each pair of columns.
    # y values are in *axis domain* coordinates (0..1), stacked just above the
    # plot area, one bracket per compared pair.
    y_range = np.zeros([len(array_columns), 2])
    for i in range(len(array_columns)):
        y_range[i] = [1.03+i*_format['interline'], 1.04+i*_format['interline']]

    # Get values from figure
    fig_dict = fig.to_dict()

    # Get indices if working with subplots
    if subplot:
        if subplot == 1:
            # Plotly names the first subplot's axis 'x', later ones 'x2', 'x3', ...
            subplot_str = ''
        else:
            subplot_str =str(subplot)
        indices = [] #Change the box index to the indices of the data for that subplot
        for index, data in enumerate(fig_dict['data']):
            #print(index, data['xaxis'], 'x' + subplot_str)
            if data['xaxis'] == 'x' + subplot_str:
                indices = np.append(indices, index)
        indices = [int(i) for i in indices]
        print((indices))
    else:
        subplot_str = ''

    # Print the p-values
    for index, column_pair in enumerate(array_columns):
        if subplot:
            data_pair = [indices[column_pair[0]], indices[column_pair[1]]]
        else:
            data_pair = column_pair

        # Mare sure it is selecting the data and subplot you want
        #print('0:', fig_dict['data'][data_pair[0]]['name'], fig_dict['data'][data_pair[0]]['xaxis'])
        #print('1:', fig_dict['data'][data_pair[1]]['name'], fig_dict['data'][data_pair[1]]['xaxis'])

        # Get the p-value: Welch's two-sided t-test on the two traces' y data.
        pvalue = stats.ttest_ind(
            fig_dict['data'][data_pair[0]]['y'],
            fig_dict['data'][data_pair[1]]['y'],
            equal_var=False,
        )[1]
        # Conventional significance stars: ns / * / ** / ***.
        if pvalue >= 0.05:
            symbol = 'ns'
        elif pvalue >= 0.01:
            symbol = '*'
        elif pvalue >= 0.001:
            symbol = '**'
        else:
            symbol = '***'
        # Vertical line (left leg of the bracket)
        fig.add_shape(type="line",
            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
            x0=column_pair[0], y0=y_range[index][0],
            x1=column_pair[0], y1=y_range[index][1],
            line=dict(color=_format['color'], width=1.5,)
        )
        # Horizontal line (top of the bracket)
        fig.add_shape(type="line",
            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
            x0=column_pair[0], y0=y_range[index][1],
            x1=column_pair[1], y1=y_range[index][1],
            line=dict(color=_format['color'], width=1.5,)
        )
        # Vertical line (right leg of the bracket)
        fig.add_shape(type="line",
            xref="x"+subplot_str, yref="y"+subplot_str+" domain",
            x0=column_pair[1], y0=y_range[index][0],
            x1=column_pair[1], y1=y_range[index][1],
            line=dict(color=_format['color'], width=1.5,)
        )
        ## add text at the correct x, y coordinates
        ## for bars, there is a direct mapping from the bar number to 0, 1, 2...
        fig.add_annotation(dict(font=dict(color=_format['color'],size=14),
            x=(column_pair[0] + column_pair[1])/2,
            y=y_range[index][1]*_format['text_height'],
            showarrow=False,
            text=symbol,
            textangle=0,
            xref="x"+subplot_str,
            yref="y"+subplot_str+" domain"
        ))
    return fig
115
+
116
+
117
def box_plot(df):
    """Grouped box plot of test AUROC per dataset/model, with p-value brackets.

    Expects ``df`` to have columns ``Task_name``, ``test_auroc`` and ``Model``
    (as produced by the wandb export read in ``__main__``). Writes the figure
    to ``../figures/box_plot_integration.png`` and shows it.
    """
    fig = px.box(df, x = 'Task_name', y='test_auroc', color="Model")

    # White background with subtle axis lines/grid.
    fig.update_layout(plot_bgcolor="white")
    fig.update_xaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0)',mirror=False)
    fig.update_yaxes(linecolor='rgba(0,0,0,0.25)', gridcolor='rgba(0,0,0,0.07)',mirror=False)
    fig.update_layout(title={'text': "<b>ROC-AUC score distribution</b>",
                             'font':{'size':40},
                             'y': 0.96,
                             'x': 0.5,
                             'xanchor': 'center',
                             'yanchor': 'top'},

                      xaxis_title={'text': "Datasets",
                                   'font':{'size':30}},
                      yaxis_title={'text': "ROC-AUC",
                                   'font':{'size':30}},

                      font=dict(family="Calibri, monospace",
                                size=17
                                ))

    # Annotate selected model pairs (trace indices 0/3/6 vs 7) with t-test stars.
    fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)

    fig.write_image('../figures/box_plot_integration.png', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()
144
+
145
+
146
+
147
def go_box_plot(df, metric = ROC):
    """Three-panel box plot (one subplot per dataset) of model scores.

    Args:
        df: results table with columns ``Task_name``, ``Model`` and the
            selected metric column (``test_auroc`` / ``test_auprc``).
        metric: ``ROC`` (default) or ``PR`` — selects metric column and
            output file name.

    Writes the figure to ``../figures/boxplot_auroc.png`` or
    ``../figures/boxplot_auprc.png``.
    """
    dataset_list = ['BIOSNAP', 'DAVIS', 'BindingDB']
    model_list = ['LR', 'DNN', 'GNN-CPI', 'DeepDTI', 'DeepDTA', 'DeepConv-DTI', 'Moltrans', 'ours']
    clr_list = ['red', 'orange', 'green', 'indianred', 'lightseagreen', 'goldenrod', 'magenta', 'blue']

    if metric == ROC:
        # fig_title = "<b>ROC-AUC score distribution</b>"
        file_title = "boxplot_auroc.png"
        select_metric = "test_auroc"
    else:
        # fig_title = "<b>PR-AUC score distribution</b>"
        file_title = "boxplot_auprc.png"
        select_metric = "test_auprc"

    fig = make_subplots(rows=1, cols=3, subplot_titles=[c for c in dataset_list])

    groups = df.groupby(df.Task_name)
    # NOTE(review): local name is a typo for "Legend"; kept to avoid code changes.
    # Only the first subplot shows the legend so models aren't listed three times.
    Legand = True

    for dataset_idx, dataset in enumerate(dataset_list):
        df_modelgroup = groups.get_group(dataset)
        model_groups = df_modelgroup.groupby(df_modelgroup.Model)
        if dataset_idx != 0:
            Legand = False
        for model_idx, model in enumerate(model_list):
            df_data = model_groups.get_group(model)
            fig.append_trace(go.Box(y=df_data[select_metric],
                                    name=model,
                                    marker_color=clr_list[model_idx],
                                    showlegend = Legand
                                    ),
                             row=1,
                             col=dataset_idx+1)




    # fig.update_layout(title={'text': fig_title,
    #                          'font':{'size':25},
    #                          'y': 0.98,
    #                          'x': 0.46,
    #                          'xanchor': 'center',
    #                          'yanchor': 'top'})

    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=1)
    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=2)
    # fig = add_p_value_annotation(fig, [[0,7], [3,7], [6,7]], subplot=3)

    fig.write_image(f'../figures/{file_title}', width=1.5*1200, height=0.75*1200, scale=2)
    fig.show()


if __name__ == '__main__':
    # Script entry point: plot the wandb export with the single-panel variant.
    df = pd.read_csv("../dataset/wandb_export_boxplotdata.csv")
    box_plot(df)
tool/dap/util/data/bindingdb_kd.tab ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b72a38ae07a75d5d4c269d2776b6e62e0edde29ff7cf8a323158c08951f808d1
3
+ size 54432102
tool/dap/util/data/davis.tab ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d4c6809dcb7c5da2b91a32d594d6935b75484940bde4d18055eb5e1059262f4
3
+ size 21376712
tool/dap/util/emetric.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
def get_cindex(Y, P):
    """Concordance index (c-index) of predictions *P* against labels *Y*.

    Counts all ordered pairs where Y[i] > Y[j]; a pair is concordant (weight 1)
    when the predictions agree on the order, and tied predictions count 0.5.

    Args:
        Y: sequence of ground-truth affinity values.
        P: sequence of predicted values, aligned with ``Y``.

    Returns:
        c-index in [0, 1], or 0 when no comparable pair exists.
    """
    summ = 0
    pair = 0

    # j < i always holds, so the original redundant `i is not j` guard is gone.
    for i in range(1, len(Y)):
        for j in range(0, i):
            if Y[i] > Y[j]:
                pair += 1
                summ += 1 * (P[i] > P[j]) + 0.5 * (P[i] == P[j])

    # Fixed: the original compared integers with `is not`, which tests object
    # identity (and is only accidentally true for small ints; SyntaxWarning
    # since Python 3.8). Value equality is what is meant.
    if pair != 0:
        return summ / pair
    return 0
19
+
20
+
21
def r_squared_error(y_obs, y_pred):
    """Return the squared Pearson correlation between observations and predictions."""
    obs = np.array(y_obs)
    pred = np.array(y_pred)

    # Deviations from the respective means.
    obs_dev = obs - np.mean(obs)
    pred_dev = pred - np.mean(pred)

    covariance = np.sum(pred_dev * obs_dev)
    obs_var = np.sum(obs_dev * obs_dev)
    pred_var = np.sum(pred_dev * pred_dev)

    return covariance * covariance / float(obs_var * pred_var)
34
+
35
+
36
def get_k(y_obs, y_pred):
    """Slope of the least-squares regression line through the origin (obs ~ k * pred)."""
    obs = np.array(y_obs)
    pred = np.array(y_pred)
    return np.sum(obs * pred) / float(np.sum(pred * pred))
41
+
42
+
43
def squared_error_zero(y_obs, y_pred):
    """R0^2: coefficient of determination of the through-origin fit ``k * y_pred``."""
    k = get_k(y_obs, y_pred)

    obs = np.array(y_obs)
    pred = np.array(y_pred)

    # Residuals of the zero-intercept model vs. total variation around the mean.
    residual = obs - k * pred
    total = obs - np.mean(obs)

    return 1 - np.sum(residual * residual) / float(np.sum(total * total))
53
+
54
+
55
def get_rm2(ys_orig, ys_line):
    """Modified squared correlation metric r_m^2 combining r^2 and R0^2."""
    r2 = r_squared_error(ys_orig, ys_line)
    r02 = squared_error_zero(ys_orig, ys_line)
    penalty = np.sqrt(np.absolute(r2 * r2 - r02 * r02))
    return r2 * (1 - penalty)
tool/dap/util/load_dataset.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tdc.multi_pred import DTI
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+
6
if __name__ == '__main__':
    # Download DTI benchmarks via TDC, log-transform the affinities, and dump
    # the train/valid/test splits as CSV under ../dataset_kd/.
    bindingDB_data = DTI(name = 'BindingDB_Kd')
    davis_data = DTI(name = 'DAVIS')

    # Collapse duplicated drug-target pairs in BindingDB to the max affinity.
    bindingDB_data.harmonize_affinities(mode = 'max_affinity')

    bindingDB_data.convert_to_log(form = 'binding')
    davis_data.convert_to_log(form = 'binding')

    split_bindingDB = bindingDB_data.get_split()
    split_davis = davis_data.get_split()

    dataset_list = ["train", "valid", "test"]
    for dataset_type in dataset_list:
        df_bindingDB = pd.DataFrame(split_bindingDB[dataset_type])
        df_davis = pd.DataFrame(split_davis[dataset_type])

        df_bindingDB.to_csv(f"../dataset_kd/bindingDB_{dataset_type}.csv", index=False)
        df_davis.to_csv(f"../dataset_kd/davis_{dataset_type}.csv", index=False)


    # NOTE(review): after the loop these frames hold only the last ("test")
    # split. Y_davis_log is never used, and the labels were already
    # log-transformed above — this looks like leftover exploration code.
    Y_bindingDB = np.array(df_bindingDB.Y)
    Y_davis = np.array(df_davis.Y)

    Y_davis_log = [np.log10(Y_davis)]
+
32
+
tool/dap/util/make_external_validation.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+
5
if __name__ == '__main__':
    # Build an external validation set as the Cartesian product of every
    # SMILES with every amino-acid sequence, writing it to
    # ../dataset/external_dataset.csv.
    smiles = pd.read_csv("../dataset/external_smiles.csv")
    ass = pd.read_csv("../dataset/external_aas.csv")

    smiles_data = list(np.array(smiles['smiles']))
    # NOTE(review): smiles_label is computed but never used below.
    smiles_label = list(np.array(smiles['label'].tolist()))
    smiles_label = [x.split() for x in smiles_label]

    ass_data = list(np.array(ass['aas']))
    cyp_type = list(np.array(ass['CYP_type']))

    # Pair every compound with every protein (all-vs-all).
    external_dataset = []
    for smiles_idx in range(0, len(smiles_data)):
        for ass_idx in range(0, len(ass_data)):

            external_data = [smiles_data[smiles_idx], ass_data[ass_idx], cyp_type[ass_idx]]
            external_dataset.append(external_data)

    df = pd.DataFrame(external_dataset, columns=['smiles', 'aas', 'CYP_type'])
    df.to_csv('../dataset/external_dataset.csv', index=False)


    # Quick sanity check of the first entries.
    print(smiles['smiles'][0])
    print(ass['CYP_type'][0])
tool/dap/util/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, copy
2
+ from easydict import EasyDict
3
+
4
+ import torch.nn as nn
5
+
6
class DictX(dict):
    """dict subclass exposing keys as attributes (``d.key`` <-> ``d['key']``).

    Missing keys raise AttributeError from attribute access, matching normal
    attribute semantics while keeping full dict behaviour.
    """

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as err:
            # Re-raise as AttributeError so getattr()/hasattr() behave normally.
            raise AttributeError(err)

    def __setattr__(self, key, value):
        # All attribute writes go straight into the mapping.
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError as err:
            raise AttributeError(err)

    def __repr__(self):
        return f"<DictX {dict.__repr__(self)}>"
24
+
25
+
26
def load_hparams(file_path):
    """Load hyper-parameters from a JSON file.

    Returns the parsed JSON object (a plain dict for a JSON object). The
    original dead ``EasyDict()`` initialiser was removed: it was immediately
    overwritten by ``json.load``, so the function has always returned the raw
    parsed value, not an EasyDict.
    """
    with open(file_path, 'r') as f:
        return json.load(f)
31
+
32
+
33
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of *model* whose encoder keeps only the first layers.

    The kept ModuleList still references the ORIGINAL layer modules (the deep
    copy's own layers are discarded by the reassignment), matching the
    original implementation's behaviour. The input model is left untouched.
    """
    # Collect references to the first num_layers_to_keep original layers.
    kept_layers = nn.ModuleList(model.encoder.layer[i] for i in range(num_layers_to_keep))

    # Copy the model, then swap in the truncated layer list.
    trimmed_model = copy.deepcopy(model)
    trimmed_model.encoder.layer = kept_layers

    return trimmed_model
tool/dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
tool/deepacceptor/RF.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Sep 4 10:38:59 2023
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+
9
+ from sklearn.metrics import confusion_matrix
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from rdkit.Chem import AllChem
13
+ from sklearn.datasets import make_blobs
14
+ import json
15
+ import numpy as np
16
+ import math
17
+
18
+ from scipy import sparse
19
+ from sklearn.metrics import median_absolute_error,r2_score, mean_absolute_error,mean_squared_error
20
+ import pickle
21
+ from tqdm import tqdm
22
+
23
+ import pandas as pd
24
+ import matplotlib.pyplot as plt
25
+ from rdkit import Chem
26
+
27
+ from sklearn.ensemble import RandomForestRegressor
28
+
29
+
30
def split_string(string):
    """Return the characters of *string* as a list.

    Idiom fix: the manual append loop is exactly the built-in ``list()``
    conversion of a string.
    """
    return list(string)
38
def main(sm):
    """Predict the acceptor property for one SMILES string with the pretrained RF model.

    Featurises *sm* as a 2048-bit radius-3 Morgan fingerprint and runs it
    through the pickled random forest. Returns the model's prediction array
    (empty input array if the SMILES cannot be parsed).

    Fixes vs. the original: the pickle file handle is closed via ``with``
    (it was leaked by ``pickle.load(open(...))``), and the dead
    ``RandomForestRegressor`` instance and unused locals were removed.
    """
    inchis = list([sm])

    features = []
    for inc in inchis:
        mol = Chem.MolFromSmiles(inc)
        if mol is None:
            continue
        # 2048-bit Morgan fingerprint (radius 3) as a '0'/'1' character vector,
        # matching the featurisation the pickled model expects.
        bit_string = AllChem.GetMorganFingerprintAsBitVect(mol, 3, 2048).ToBitString()
        features.append(np.array(list(bit_string)))

    X_test = np.asarray(features)

    # SECURITY NOTE: unpickling executes arbitrary code — only load trusted files.
    with open(r"tool\deepacceptor\deepacceptor.pkl", "rb") as model_file:
        load_model = pickle.load(model_file)

    Y_predict = load_model.predict(X_test)
    return Y_predict
tool/deepacceptor/deepacceptor.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11753f5d925de9fdf0ff0afc05204bcb54ad26753f13e02e63696fa3c65e8029
3
+ size 28084161
tool/deepacceptor/dict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [" ", "C", "1", "=", "(", "2", "F", ")", "3", "4", "5", "#", "N", "S", "/", "\\", "O", "6", "7", "8", "9", "%", "0", "[", "Se", "]", "Cl", "Br", "B", ".", "P", "I", "@", "H"]
tool/deepdonor/pm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:155cb847cef95d069044c425c409ca8daff368bd2f3310f43965b6c65a2914e2
3
+ size 8220594
tool/deepdonor/sm.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b27366195bfd2fcb74dbe20ccc1243e6297ee1f5c272947613f875141fdceb2
3
+ size 31982999
tool/graphconverter.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Nov 7 15:38:35 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from DECIMER import predict_SMILES
9
+ from langchain.tools import BaseTool
10
+
11
class graphconverter(BaseTool):
    """LangChain tool: OCSR — convert an image of a molecule to SMILES via DECIMER."""

    name: str = "graphconverter"
    description: str = (
        "Input graph path , returns SMILES."
        "It was used to convert graph/figure/image containing molecule to SMILES"
    )

    def __init__(self):
        super().__init__()

    def _run(self, paths: str) -> str:
        """Run DECIMER on the image at *paths*; return the SMILES or an error hint."""
        try:
            SMILES = predict_SMILES(paths)
        # Fixed: bare `except:` also swallowed SystemExit/KeyboardInterrupt;
        # catch Exception so genuine interrupts still propagate.
        except Exception:
            return 'Please recheck the graph path'
        return SMILES

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
31
+
32
+
33
+
tool/orbital.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Wed Oct 30 09:14:55 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+
9
+ from sklearn.metrics import confusion_matrix
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from rdkit.Chem import AllChem
13
+ from sklearn.datasets import make_blobs
14
+ import json
15
+ import numpy as np
16
+ import math
17
+
18
+ from scipy import sparse
19
+ from sklearn.metrics import median_absolute_error,r2_score, mean_absolute_error,mean_squared_error
20
+ from langchain.tools import BaseTool
21
+ import pandas as pd
22
+ import matplotlib.pyplot as plt
23
+ from rdkit import Chem
24
+ import pickle
25
+ from sklearn.ensemble import RandomForestRegressor
26
+
27
+
28
def split_string(string):
    """Return the characters of *string* as a list.

    Idiom fix: the manual append loop is exactly the built-in ``list()``
    conversion of a string.
    """
    return list(string)
37
+
38
def main(sm):
    """Predict (HOMO, LUMO) in eV for one SMILES with the pretrained RF models.

    Featurises *sm* as a Morgan fingerprint bit string and runs it through the
    pickled ``homo``/``lumo`` random forests.

    Fixes vs. the original: both pickle handles are closed via ``with`` (they
    were leaked), the dead ``RandomForestRegressor`` instance and unused
    locals were removed, and ``float()`` is taken from the first element
    instead of the whole size-1 array (deprecated in recent NumPy).
    """
    inchis = list([sm])

    features = []
    for inc in inchis:
        mol = Chem.MolFromSmiles(inc)
        if mol is None:
            continue
        # FIXME(review): GetMorganFingerprintAsBitVect(mol, 1024) passes 1024
        # as the *radius* (nBits stays at the 2048 default). This is almost
        # certainly not what was intended (e.g. radius 2 with nBits=1024), but
        # the call is kept as-is pending confirmation of the featurisation the
        # pickled models were trained on.
        bit_string = AllChem.GetMorganFingerprintAsBitVect(mol, 1024).ToBitString()
        features.append(np.array(list(bit_string)))

    X_test = np.asarray(features)

    # SECURITY NOTE: unpickling executes arbitrary code — only load trusted files.
    with open(r"tool/orbital/homo.pkl", 'rb') as homo_file:
        load_homo = pickle.load(homo_file)
    with open(r"tool/orbital/lumo.pkl", 'rb') as lumo_file:
        load_lumo = pickle.load(lumo_file)

    Y_homo = load_homo.predict(X_test)
    Y_lumo = load_lumo.predict(X_test)
    homo = float(Y_homo[0])
    lumo = float(Y_lumo[0])
    return homo, lumo
75
+
76
class homolumo_predictor(BaseTool):
    # LangChain tool: predict frontier-orbital energies (HOMO/LUMO, in eV)
    # for a SMILES string via this module's main() random-forest pipeline.
    name: str = "homolumo_predictor"
    description: str = (
        "Input SMILES , returns the HOMO/LUMO (Highest Occupied Molecular Orbital (HOMO) \
        and Lowest Unoccupied Molecular Orbital)."
    )
    def __init__(self):
        super().__init__()
    def _run(self, smiles: str) -> str:
        # Validate the SMILES before handing it to the (slow) model pipeline.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        Y_homo, Y_lumo = main( str(smiles) )
        return f"The HOMO is predicted to be {'{:.2f}'.format(Y_homo)} eV , the LUMO is predicted to be {'{:.2f}'.format(Y_lumo)} eV"

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
tool/pdfreader.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Dec 30 22:20:13 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+ from langchain.chains import LLMChain, SimpleSequentialChain, RetrievalQA, ConversationalRetrievalChain
8
+
9
+ from langchain import PromptTemplate
10
+
11
+ from langchain.tools import BaseTool
12
+
13
+ from langchain_core.messages import HumanMessage, SystemMessage
14
+ from langchain.base_language import BaseLanguageModel
15
+ from langchain.text_splitter import CharacterTextSplitter
16
+
17
+
18
+ from langchain_community.document_loaders import PyPDFLoader
19
+ from langchain_community.vectorstores import FAISS
20
+ from langchain_openai import ChatOpenAI
21
+ from langchain_openai import OpenAIEmbeddings
22
+
23
+ template = """
24
+
25
+ You are an expert chemist and your task is to respond to the question or
26
+ solve the problem to the best of your ability. You need to answer in as much detail as possible.
27
+ You can only respond with a single "Final Answer" format.
28
+ Use the following pieces of context to answer the question at the end.
29
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
30
+ <context>
31
+ {context}
32
+ </context>
33
+
34
+ Question: {question}
35
+ Answer:
36
+
37
+ """
38
+
39
class pdfreader(BaseTool):
    # LangChain tool: retrieval-augmented Q&A over a single PDF. Each call
    # rebuilds the FAISS index from self.path, retrieves the top-2 chunks and
    # answers via an OpenAI-compatible chat model.
    name: str = "pdfreader"
    description: str = (

        "Used to read papers, summarize papers, Q&A based on papers, literature or publication"
        "Input query , return the response"
    )

    # llm: chat model used for answering; path: PDF file to index.
    llm: BaseLanguageModel = None
    path : str = None
    return_direct: bool = True
    def __init__(self, path: str = None):
        super().__init__( )
        # SECURITY(review): API key is hardcoded and checked into source —
        # it should be loaded from an environment variable and this key revoked.
        self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
                              base_url="https://www.dmxapi.com/v1")
        self.path = path
        # api keys

    def _run(self, query ) -> str:
        # Load and chunk the PDF; chunks are large (6k chars) with 1k overlap.
        loader = PyPDFLoader(self.path)
        documents = loader.load()

        text_splitter = CharacterTextSplitter(chunk_size=6000, chunk_overlap=1000)
        docs = text_splitter.split_documents(documents)
        # SECURITY(review): same hardcoded API key as in __init__ — see note above.
        embeddings = OpenAIEmbeddings(api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
                                      base_url="https://www.dmxapi.com/v1")


        # NOTE(review): the FAISS index is rebuilt on every query — caching it
        # per path would avoid re-embedding the whole document each call.
        vectorstore = FAISS.from_documents(docs, embeddings)
        prompt = PromptTemplate(template=template, input_variables=[ "question"])
        qa_chain = RetrievalQA.from_chain_type(
            llm= self.llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )

        result = qa_chain.invoke(query)
        return result['result']


    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
85
+
86
+
tool/property.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Sep 5 21:42:51 2024
4
+
5
+ @author: BM109X32G-10GPU-02
6
+ """
7
+
8
+ from langchain.tools import BaseTool
9
+ from rdkit import Chem
10
+ from rdkit.Chem import rdMolDescriptors
11
+ from rdkit.Chem import Descriptors
12
+ from utils import *
13
+ from rdkit.Chem import RDConfig
14
+ from rdkit.ML.Descriptors import MoleculeDescriptors
15
+
16
+ from rdkit.Contrib.SA_Score import sascorer
17
+
18
+
19
class MolSimilarity(BaseTool):
    """LangChain tool: Tanimoto similarity between two molecules given as 'SMILES.SMILES'."""

    name: str = "MolSimilarity"
    description: str = (
        "Input two molecule SMILES (separated by '.'), returns Tanimoto similarity."
    )

    def __init__(self):
        super().__init__()

    def _run(self, smiles_pair: str) -> str:
        # Guard clauses instead of nested if/else.
        parts = smiles_pair.split(".")
        if len(parts) != 2:
            return "Input error, please input two smiles strings separated by '.'"
        smiles1, smiles2 = parts

        similarity = tanimoto(smiles1, smiles2)

        # tanimoto() reports its own errors as strings; pass them straight through.
        if isinstance(similarity, str):
            return similarity

        if similarity == 1:
            return "Error: Input Molecules Are Identical"

        return f"The Tanimoto similarity between {smiles1} and {smiles2} is {round(similarity, 4)}"

    async def _arun(self, smiles_pair: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
51
+
52
class SMILES2Weight(BaseTool):
    """LangChain tool: exact molecular weight from a SMILES string."""

    name: str = "SMILES2Weight"
    description: str = "Input SMILES, returns molecular weight."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        return rdMolDescriptors.CalcExactMolWt(mol)

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
72
class SMILES2LogP(BaseTool):
    """LangChain tool: Crippen LogP from a SMILES string."""

    name: str = "SMILES2LogP"
    description: str = "Input SMILES, returns molecular LogP."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        return Descriptors.MolLogP(mol)

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
92
class SMILES2SAScore(BaseTool):
    """LangChain tool: synthetic accessibility (SA) score from a SMILES string."""

    name: str = "SMILES2SAScore"
    description: str = "Input SMILES, returns synthetic accessibility score to evaluate the difficulty of molecular synthesis."

    def __init__(self):
        super().__init__()

    def _run(self, smiles: str) -> str:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        score = sascorer.calculateScore(mol)
        return f"This SAScore of the molecule is {score}."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
112
class SMILES2Properties(BaseTool):
    # LangChain tool: one-shot summary of basic physico-chemical descriptors
    # (SA score plus a fixed RDKit descriptor list) for a SMILES string.
    name: str = "SMILES2Properties"
    description: str = "Input SMILES, returns basic physical and chemical properties."

    def __init__(
        self,
    ):
        super().__init__()

    def _run(self, smiles: str) -> str:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES string"
        SAScore = sascorer.calculateScore(mol)
        # Descriptor names must match rdkit.Chem.Descriptors exactly; order
        # here determines the indices used in the summary string below.
        des_list = ['MolWt','NOCount', 'NumHAcceptors', 'NumHDonors', 'MolLogP', 'NumRotatableBonds','RingCount','NumAromaticRings','TPSA']
        calculator = MoleculeDescriptors.MolecularDescriptorCalculator(des_list)
        results = calculator.CalcDescriptors(mol)


        return f"SAScore: {'{:.2f}'.format(SAScore)}; molecular weight: {'{:.2f}'.format(results[0])}; number of Nitrogens and Oxygens: {results[1]}; number of Hydrogen Bond Acceptors: {results[2]}; number of Hydrogen Bond Donors:{results[3]}; LogP:{'{:.2f}'.format(results[4])}; number of Rotatable Bonds: {results[5]}; Ring count: {results[6]}; number of aromatic rings: {results[7]}; TPSA: {'{:.2f}'.format(results[8])}."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+
137
class FuncGroups(BaseTool):
    # LangChain tool: list the functional groups present in a molecule by
    # matching a fixed table of SMARTS patterns.
    name: str = "FunctionalGroups"
    description: str = "Input SMILES, return list of functional groups in the molecule."
    # Mapping of human-readable group name -> SMARTS pattern.
    dict_fgs: dict = None

    def __init__(
        self,
    ):
        super().__init__()

        # List obtained from https://github.com/rdkit/rdkit/blob/master/Data/FunctionalGroups.txt
        # NOTE(review): "ketones" appears twice in this literal; the second
        # definition ("*=[O;D1]") silently overrides the first
        # (" [#6][CX3](=O)[#6]") — probably unintended, needs a distinct name.
        self.dict_fgs = {
            "furan": "o1cccc1",
            "aldehydes": " [CX3H1](=O)[#6]",
            "esters": " [#6][CX3](=O)[OX2H0][#6]",
            "ketones": " [#6][CX3](=O)[#6]",
            "amides": " C(=O)-N",
            "thiol groups": " [SH]",
            "alcohol groups": " [OH]",
            "methylamide": "*-[N;D2]-[C;D3](=O)-[C;D1;H3]",
            "carboxylic acids": "*-C(=O)[O;D1]",
            "carbonyl methylester": "*-C(=O)[O;D2]-[C;D1;H3]",
            "terminal aldehyde": "*-C(=O)-[C;D1]",
            "amide": "*-C(=O)-[N;D1]",
            "carbonyl methyl": "*-C(=O)-[C;D1;H3]",
            "isocyanate": "*-[N;D2]=[C;D2]=[O;D1]",
            "isothiocyanate": "*-[N;D2]=[C;D2]=[S;D1]",
            "nitro": "*-[N;D3](=[O;D1])[O;D1]",
            "nitroso": "*-[N;R0]=[O;D1]",
            "oximes": "*=[N;R0]-[O;D1]",
            "Imines": "*-[N;R0]=[C;D1;H2]",
            "terminal azo": "*-[N;D2]=[N;D2]-[C;D1;H3]",
            "hydrazines": "*-[N;D2]=[N;D1]",
            "diazo": "*-[N;D2]#[N;D1]",
            "cyano": "*-[C;D2]#[N;D1]",
            "primary sulfonamide": "*-[S;D4](=[O;D1])(=[O;D1])-[N;D1]",
            "methyl sulfonamide": "*-[N;D2]-[S;D4](=[O;D1])(=[O;D1])-[C;D1;H3]",
            "sulfonic acid": "*-[S;D4](=O)(=O)-[O;D1]",
            "methyl ester sulfonyl": "*-[S;D4](=O)(=O)-[O;D2]-[C;D1;H3]",
            "methyl sulfonyl": "*-[S;D4](=O)(=O)-[C;D1;H3]",
            "sulfonyl chloride": "*-[S;D4](=O)(=O)-[Cl]",
            "methyl sulfinyl": "*-[S;D3](=O)-[C;D1]",
            "methyl thio": "*-[S;D2]-[C;D1;H3]",
            "thiols": "*-[S;D1]",
            "thio carbonyls": "*=[S;D1]",
            "halogens": "*-[#9,#17,#35,#53]",
            "t-butyl": "*-[C;D4]([C;D1])([C;D1])-[C;D1]",
            "tri fluoromethyl": "*-[C;D4](F)(F)F",
            "acetylenes": "*-[C;D2]#[C;D1;H]",
            "cyclopropyl": "*-[C;D3]1-[C;D2]-[C;D2]1",
            "ethoxy": "*-[O;D2]-[C;D2]-[C;D1;H3]",
            "methoxy": "*-[O;D2]-[C;D1;H3]",
            "side-chain hydroxyls": "*-[O;D1]",
            "ketones": "*=[O;D1]",
            "primary amines": "*-[N;D1]",
            "nitriles": "*#[N;D1]",
        }

    def _is_fg_in_mol(self, mol, fg):
        # mol is a SMILES string here despite the name; it is parsed fresh
        # for every pattern checked.
        fgmol = Chem.MolFromSmarts(fg)
        mol = Chem.MolFromSmiles(mol.strip())
        return len(Chem.Mol.GetSubstructMatches(mol, fgmol, uniquify=True)) > 0

    def _run(self, smiles: str) -> str:
        """
        Input a molecule SMILES or name.
        Returns a list of functional groups identified by their common name (in natural language).
        """
        # NOTE(review): a molecule matching zero groups raises IndexError on
        # fgs_in_molec[0], which the broad except converts into the generic
        # "Wrong argument" message.
        try:
            fgs_in_molec = [
                name
                for name, fg in self.dict_fgs.items()
                if self._is_fg_in_mol(smiles, fg)
            ]
            if len(fgs_in_molec) > 1:
                return f"This molecule contains {', '.join(fgs_in_molec[:-1])}, and {fgs_in_molec[-1]}."
            else:
                return f"This molecule contains {fgs_in_molec[0]}."
        except:
            return "Wrong argument. Please input a valid molecular SMILES."

    async def _arun(self, smiles: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
+ raise NotImplementedError()
tool/rag.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 2 20:31:22 2025

@author: BM109X32G-10GPU-02
"""


from langchain.tools import BaseTool

from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

from langchain import PromptTemplate
from langchain import HuggingFacePipeline

from langchain.base_language import BaseLanguageModel



from langchain.chains import RetrievalQA

from langchain_community.document_loaders import PyPDFLoader

from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from torch import cuda, bfloat16

# Pick the current CUDA device when a GPU is available, otherwise fall back to CPU.
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
from langchain_openai import OpenAIEmbeddings

# SECURITY NOTE(review): the OpenAI-compatible API key is hardcoded in source —
# it should be rotated and loaded from the environment/configuration instead.
embeddings = OpenAIEmbeddings(api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
base_url="https://www.dmxapi.com/v1")

# Module-level FAISS index shared by the `rag` tool below. The path is an
# absolute Windows path — presumably the author's machine; TODO make configurable.
# NOTE(review): allow_dangerous_deserialization=True unpickles index.pkl — only
# safe because the index files ship with this repository; confirm their origin.
vectorstore=FAISS.load_local(r"J:\libray\osc\tool\rag", embeddings,allow_dangerous_deserialization =True)


# Prompt used by the RetrievalQA "stuff" chain: {context} is filled with the
# retrieved documents, {question} with the user query.
template = """

You are an expert chemist and your task is to respond to the question or
solve the problem to the best of your ability.You can only respond with a single "Final Answer" format.
You need to list the key points and explain them in detail and accurately
Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
<context>
{context}
</context>

Question: {question}
Answer:

"""
53
+
54
+
55
class rag(BaseTool):
    """Retrieval-augmented Q&A tool backed by the module-level FAISS vectorstore.

    Retrieves the top-5 most similar chunks and answers with a "stuff" chain
    using the module-level `template` prompt.
    """

    name: str = "RAG"
    # Bug fix: the first description sentence was truncated ("...require technical ").
    description: str = (
        "Useful to answer questions that require technical knowledge. "
        "Provide specialized knowledge information for solving Q&A questions"
        "Input query , return the response"
    )

    llm: BaseLanguageModel = None  # chat model, created in __init__
    path: str = None               # optional working path kept for callers

    def __init__(self, path: str = None):
        super().__init__()
        # SECURITY NOTE(review): hardcoded API key — rotate it and load from
        # the environment/configuration instead of source control.
        self.llm = ChatOpenAI(model="gpt-4o-2024-11-20",api_key='sk-itPrztYm9F6XZZpsBMJB9O7Vq0pYUABVVBSoThuBxEGTnDik',
                              base_url="https://www.dmxapi.com/v1")
        self.path = path
        # api keys

    def _run(self, query) -> str:
        """Answer `query` using retrieval-augmented generation; returns the answer text."""
        # Bug fix: the prompt template references both {context} and {question},
        # but only "question" was declared, leaving the declaration inconsistent
        # with what RetrievalQA actually supplies at format time.
        prompt = PromptTemplate(template=template, input_variables=["context", "question"])
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            # k=5: stuff the five most similar chunks into the prompt context.
            retriever=vectorstore.as_retriever(search_kwargs={"k": 5}),
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt},
        )
        # (removed an unused `chat_history = []` local)
        result = qa_chain.invoke(query)
        return result['result']

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
96
+
97
+
98
+
99
+
100
+
101
+
tool/rag/index.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50ab8fd7a0c8d9dd62ebb1592b542b2d2fb2730ce21b761d2b36e8b5087743cf
3
+ size 6942765
tool/rag/index.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:908a85200a6d6e140c7d112f5c1b6b73376b42b2f85c3b3786f2050d6f1070d9
3
+ size 5545448
tool/search.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import langchain
5
+
6
+ import paperqa
7
+ import paperscraper
8
+ from langchain_community.utilities import SerpAPIWrapper
9
+ from langchain.base_language import BaseLanguageModel
10
+ from langchain.tools import BaseTool
11
+ from langchain_openai import OpenAIEmbeddings
12
+ from pypdf.errors import PdfReadError
13
+ from rdkit import Chem, DataStructs
14
+ from rdkit.Chem import AllChem
15
+
16
def is_smiles(text):
    """Return True if `text` parses as a SMILES string (no sanitization pass)."""
    try:
        # MolFromSmiles returns None for unparseable input rather than raising.
        return Chem.MolFromSmiles(text, sanitize=False) is not None
    except Exception:
        # Narrowed from a bare except (which also swallowed KeyboardInterrupt);
        # rdkit can still raise on non-string input, which we treat as "not SMILES".
        return False
24
+
25
+
26
+
27
def is_multiple_smiles(text):
    """True when `text` is valid SMILES made of several '.'-separated fragments."""
    return is_smiles(text) and "." in text
31
+
32
+
33
def split_smiles(text):
    """Split a multi-fragment SMILES string on the '.' fragment separator."""
    fragment_separator = "."
    return text.split(fragment_separator)
35
+
36
def paper_scrap(search: str, pdir: str = "query", semantic_scholar_api_key: str = None) -> dict:
    """Search paperscraper for papers matching `search`.

    Returns the {path: metadata} mapping from paperscraper, or an empty dict
    when the scrape fails with a KeyError.
    """
    try:
        results = paperscraper.search_papers(
            search,
            pdir=pdir,
            semantic_scholar_api_key=semantic_scholar_api_key,
        )
    except KeyError:
        return {}
    return results
45
+
46
+
47
def paper_search(llm, query, semantic_scholar_api_key=None):
    """Turn `query` into a short search string via `llm`, then scrape papers for it.

    Returns the {path: metadata} dict produced by `paper_scrap` (may be empty).
    """
    prompt = langchain.prompts.PromptTemplate(
        input_variables=["question"],
        template="""
        I would like to find scholarly papers to answer
        this question: {question}. Your response must be at
        most 10 words long.
        'A search query that would bring up papers that can answer
        this question would be: '""",
    )

    query_chain = langchain.chains.llm.LLMChain(llm=llm, prompt=prompt)
    # Bug fix: the original check-then-mkdir ("./query" vs "query/") was racy
    # and crashed if the directory appeared between the check and the mkdir.
    os.makedirs("query", exist_ok=True)  # todo: move to ckpt
    search = query_chain.run(query)
    print("\nSearch:", search)
    papers = paper_scrap(
        search,
        pdir=f"query/{re.sub(' ', '', search)}",
        semantic_scholar_api_key=semantic_scholar_api_key,
    )
    return papers
65
+
66
+
67
def scholar2result_llm(llm, query, k=5, max_sources=2, openai_api_key=None, semantic_scholar_api_key=None):
    """Useful to answer questions that require
    technical knowledge. Ask a specific question.

    Searches for papers with `paper_search`, loads them into a paperqa.Docs
    corpus, and returns the formatted answer (or a message if nothing was found).
    """
    papers = paper_search(llm, query, semantic_scholar_api_key=semantic_scholar_api_key)
    if not papers:
        return "Not enough papers found"
    docs = paperqa.Docs(
        llm=llm,
        summary_llm=llm,
        embeddings=OpenAIEmbeddings(openai_api_key=openai_api_key),
    )
    not_loaded = 0
    for path, data in papers.items():
        try:
            docs.add(path, data["citation"])
        except (ValueError, FileNotFoundError, PdfReadError):
            # Skip papers that are missing, unparseable, or rejected by paperqa.
            not_loaded += 1

    # len(papers) replaces the needless len(papers.items()) view construction.
    if not_loaded > 0:
        print(f"\nFound {len(papers)} papers but couldn't load {not_loaded}.")
    else:
        print(f"\nFound {len(papers)} papers and loaded all of them.")

    answer = docs.query(query, k=k, max_sources=max_sources).formatted_answer
    return answer
92
+
93
+
94
# LangChain tool that answers technical questions by scraping and reading papers.
class Scholar2ResultLLM(BaseTool):
    name : str = "LiteratureSearch"
    description: str = (
        "Useful to answer questions that require technical "
        "knowledge. Ask a specific question."
    )
    # Dependencies and credentials, populated in __init__.
    llm: BaseLanguageModel = None
    openai_api_key: str = None
    semantic_scholar_api_key: str = None


    def __init__(self, llm, openai_api_key, semantic_scholar_api_key):
        super().__init__()
        self.llm = llm
        # api keys
        self.openai_api_key = openai_api_key
        self.semantic_scholar_api_key = semantic_scholar_api_key

    def _run(self, query) -> str:
        # Delegates to the module-level pipeline:
        # LLM -> search string -> paper scrape -> paperqa answer.
        return scholar2result_llm(
            self.llm,
            query,
            openai_api_key=self.openai_api_key,
            semantic_scholar_api_key=self.semantic_scholar_api_key
        )

    async def _arun(self, query) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("this tool does not support async")
123
+
124
+
125
def web_search(keywords, search_engine="google",
               serp_api_key='3795acda6a74ea15033d34b54eac82982b26f559147d9cf04aca4bfca91c3e9d'):
    """Run a SerpAPI web search for `keywords` and return the result text.

    `serp_api_key` is new and defaults to the previously hardcoded key, so
    existing callers are unaffected while new callers can supply their own.
    SECURITY NOTE(review): the default key is committed to source control —
    rotate it and supply the key via configuration.
    """
    try:
        return SerpAPIWrapper(
            serpapi_api_key=serp_api_key, search_engine=search_engine
        ).run(keywords)
    except Exception:
        # Narrowed from a bare except; network/quota/key failures degrade to a hint.
        return "No results, try another search"
132
+
133
+
134
# LangChain tool wrapper around the module-level `web_search` helper.
class WebSearch(BaseTool):
    name: str = "WebSearch"
    description: str = (
        "Input a specific question, returns an answer from web search. "
        "Give more detailed information and use more general features to formulate your questions."
    )
    # Caller-supplied SerpAPI key; see NOTE in _run about how it is used.
    serp_api_key: str = None

    def __init__(self, serp_api_key: str = None):
        super().__init__()
        self.serp_api_key = serp_api_key

    def _run(self, query: str) -> str:
        # The key acts only as an availability gate here.
        if not self.serp_api_key:
            return (
                "No SerpAPI key found. This tool may not be used without a SerpAPI key."
            )
        # NOTE(review): self.serp_api_key is never forwarded — web_search uses
        # its own embedded key; confirm whether it should be passed through.
        return web_search(query)

    async def _arun(self, query: str) -> str:
        raise NotImplementedError("Async not implemented")
155
+
156
+