cafierom commited on
Commit
7c2a63c
·
verified ·
1 Parent(s): 950b6a7

Upload ProteinAgent_HFS.py

Browse files
Files changed (1) hide show
  1. ProteinAgent_HFS.py +814 -0
ProteinAgent_HFS.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Annotated, TypedDict, Literal
3
+ from langchain_community.tools import DuckDuckGoSearchRun
4
+ from langchain_core.tools import tool
5
+ from langgraph.prebuilt import ToolNode, tools_condition
6
+ from langgraph.graph import StateGraph, START, END
7
+ from langgraph.graph.message import add_messages
8
+ from langchain_core.messages import SystemMessage, trim_messages, AIMessage, HumanMessage, ToolCall
9
+
10
+ from langchain_huggingface.llms import HuggingFacePipeline
11
+ from langchain_huggingface import ChatHuggingFace
12
+ from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
13
+ from langchain_core.runnables import chain
14
+ from uuid import uuid4
15
+ import re
16
+ import matplotlib.pyplot as plt
17
+
18
+ from rdkit import Chem
19
+ from rdkit.Chem import AllChem, QED
20
+ from rdkit.Chem import Draw
21
+ from rdkit.Chem.Draw import MolsToGridImage
22
+ from rdkit import rdBase
23
+ from rdkit.Chem import rdMolAlign
24
+ import os, re
25
+ from rdkit import RDConfig
26
+ import gradio as gr
27
+ from PIL import Image
28
+
29
+ import numpy as np
30
+ import pandas as pd
31
+ from chembl_webresource_client.new_client import new_client
32
+ from tqdm.auto import tqdm
33
+ import requests
34
+ import spaces
35
+
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+
38
+ hf = HuggingFacePipeline.from_model_id(
39
+ model_id= "microsoft/Phi-4-mini-instruct",
40
+ task="text-generation",
41
+ pipeline_kwargs = {"max_new_tokens": 500, "temperature": 0.4})
42
+
43
+ chat_model = ChatHuggingFace(llm=hf)
44
+
45
+ class State(TypedDict):
46
+ '''
47
+ The state of the agent.
48
+ '''
49
+ messages: Annotated[list, add_messages]
50
+ query_smiles: str
51
+ query_task: str
52
+ query_protein: str
53
+ query_up_id: str
54
+ query_pdb: str
55
+ query_chembl: str
56
+ tool_choice: tuple
57
+ which_tool: int
58
+ props_string: str
59
+ loop_again: str
60
+ #(Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"],
61
+ # Literal["lipinski_tool", "substitution_tool", "pharm_feature_tool"])
62
+
63
+
64
+ def uniprot_node(state: State) -> State:
65
+ '''
66
+ This tool takes in the user requested protein and searches UNIPROT for matches.
67
+ It returns a string scontaining the protein ID, gene name, organism, and protein name.
68
+
69
+ Args:
70
+ query_protein: the name of the protein to search for.
71
+ Returns:
72
+ protein_string: a string containing the protein ID, gene name, organism, and protein name.
73
+
74
+ '''
75
+ print("UNIPROT tool")
76
+ print('===================================================')
77
+
78
+ protein_name = state["query_protein"]
79
+ current_props_string = state["props_string"]
80
+
81
+ try:
82
+ url = f'https://rest.uniprot.org/uniprotkb/search?query={protein_name}&format=tsv'
83
+ response = requests.get(url).text
84
+
85
+ f = open(f"{protein_name}_uniprot_ids.tsv", "w")
86
+ f.write(response)
87
+ f.close()
88
+
89
+ prot_df = pd.read_csv(f'{protein_name}_uniprot_ids.tsv', sep='\t')
90
+ prot_human_df = prot_df[prot_df['Organism'] == "Homo sapiens (Human)"]
91
+ print(f"Found {len(prot_human_df)} Human proteins out of {len(prot_df)} total proteins")
92
+
93
+ prot_ids = prot_df['Entry'].tolist()
94
+ prot_ids_human = prot_human_df['Entry'].tolist()
95
+
96
+ genes = prot_df['Gene Names'].tolist()
97
+ genes_human = prot_human_df['Gene Names'].tolist()
98
+
99
+ organisms = prot_df['Organism'].tolist()
100
+
101
+ names = prot_df['Protein names'].tolist()
102
+ names_human = prot_human_df['Protein names'].tolist()
103
+
104
+ protein_string = ''
105
+ for id, gene, organism, name in zip(prot_ids, genes, organisms, names):
106
+ protein_string += f'Protein ID: {id}, Gene: {gene}, Organism: {organism}, Name: {name}\n'
107
+
108
+ except:
109
+ protein_string = 'No proteins found'
110
+
111
+ current_props_string += protein_string
112
+ state["props_string"] = current_props_string
113
+ state["which_tool"] += 1
114
+ return state
115
+
116
+ def get_qed(smiles):
117
+ '''
118
+ Helper function to compute QED for a given molecule.
119
+ Args:
120
+ smiles: the input smiles string
121
+ Returns:
122
+ qed: the QED score of the molecule.
123
+ '''
124
+ mol = Chem.MolFromSmiles(smiles)
125
+ qed = Chem.QED.default(mol)
126
+
127
+ return qed
128
+
129
+ def listbioactives_node(state: State) -> State:
130
+ '''
131
+ Accepts a UNIPROT ID and searches for bioactive molecules
132
+ Args:
133
+ up_id: the UNIPROT ID of the protein to search for.
134
+ Returns:
135
+ props_string: the number of bioactive molecules for the given protein
136
+ '''
137
+ print("List bioactives tool")
138
+ print('===================================================')
139
+
140
+ up_id = state["query_up_id"].strip()
141
+ current_props_string = state["props_string"]
142
+
143
+ targets = new_client.target
144
+ bioact = new_client.activity
145
+
146
+ try:
147
+ target_info = targets.get(target_components__accession=up_id).only("target_chembl_id","organism", "pref_name", "target_type")
148
+ target_info = pd.DataFrame.from_records(target_info)
149
+ print(target_info)
150
+ if len(target_info) > 0:
151
+ print(f"Found info for Uniprot ID: {up_id}")
152
+
153
+ chembl_ids = target_info['target_chembl_id'].tolist()
154
+
155
+ chembl_ids = list(set(chembl_ids))
156
+ print(f"Found {len(chembl_ids)} unique ChEMBL IDs")
157
+
158
+ len_all_bioacts = []
159
+ bioact_string = ''
160
+ for chembl_id in chembl_ids:
161
+ bioact_chosen = bioact.filter(target_chembl_id=chembl_id, type="IC50", relation="=").only(
162
+ "molecule_chembl_id",
163
+ "type",
164
+ "standard_units",
165
+ "relation",
166
+ "standard_value",
167
+ )
168
+ len_this_bioacts = len(bioact_chosen)
169
+ len_all_bioacts.append(len_this_bioacts)
170
+ this_bioact_string = f"Lenth of Bioactivities for ChEMBL ID {chembl_id}: {len_this_bioacts}"
171
+
172
+ bioact_string += this_bioact_string + '\n'
173
+ except:
174
+ bioact_string = 'No bioactives found\n'
175
+
176
+ current_props_string += bioact_string
177
+ state["props_string"] = current_props_string
178
+ state["which_tool"] += 1
179
+ return state
180
+
181
+ def getbioactives_node(state: State) -> State:
182
+ '''
183
+ Accepts a Chembl ID and get all bioactives molecule SMILES and IC50s for that ID
184
+ Args:
185
+ chembl_id: the chembl ID to query
186
+ Returns:
187
+ props_string: the bioactive molecule SMILES and IC50s for the chembl ID
188
+ '''
189
+ print("Get bioactives tool")
190
+ print('===================================================')
191
+
192
+ chembl_id = state["query_chembl"].strip()
193
+ current_props_string = state["props_string"]
194
+
195
+ compounds = new_client.molecule
196
+ bioact = new_client.activity
197
+
198
+ bioact_chosen = bioact.filter(target_chembl_id=chembl_id, type="IC50", relation="=").only(
199
+ "molecule_chembl_id",
200
+ "type",
201
+ "standard_units",
202
+ "relation",
203
+ "standard_value",
204
+ )
205
+
206
+ chembl_ids = []
207
+ ic50s = []
208
+ for record in bioact_chosen:
209
+ if record["standard_units"] == 'nM':
210
+ chembl_ids.append(record["molecule_chembl_id"])
211
+ ic50s.append(float(record["standard_value"]))
212
+
213
+ bioact_dict = {'chembl_ids' : chembl_ids, 'IC50s': ic50s}
214
+ bioact_df = pd.DataFrame.from_dict(bioact_dict)
215
+ bioact_df.drop_duplicates(subset=["chembl_ids"], keep= "last")
216
+ print(f"Number of records: {len(bioact_df)}")
217
+ print(bioact_df.shape)
218
+
219
+
220
+ compounds_provider = compounds.filter(molecule_chembl_id__in=bioact_df["chembl_ids"].to_list()).only(
221
+ "molecule_chembl_id",
222
+ "molecule_structures"
223
+ )
224
+
225
+ cids_list = []
226
+ smiles_list = []
227
+
228
+ for record in compounds_provider:
229
+ cid = record['molecule_chembl_id']
230
+ cids_list.append(cid)
231
+
232
+ if record['molecule_structures']:
233
+ if record['molecule_structures']['canonical_smiles']:
234
+ smile = record['molecule_structures']['canonical_smiles']
235
+ else:
236
+ print("No canonical smiles")
237
+ smile = None
238
+ else:
239
+ print('no structures')
240
+ smile = None
241
+ smiles_list.append(smile)
242
+
243
+ new_dict = {'SMILES': smiles_list, 'chembl_ids_2': cids_list}
244
+ new_df = pd.DataFrame.from_dict(new_dict)
245
+
246
+ total_bioact_df = pd.merge(bioact_df, new_df, left_on='chembl_ids', right_on='chembl_ids_2')
247
+ print(f"number of records: {len(total_bioact_df)}")
248
+
249
+ total_bioact_df.drop_duplicates(subset=["chembl_ids"], keep= "last")
250
+ print(f"number of records after removing duplicates: {len(total_bioact_df)}")
251
+
252
+ total_bioact_df.dropna(axis=0, how='any', inplace=True)
253
+ total_bioact_df.drop(["chembl_ids_2"],axis=1,inplace=True)
254
+ print(f"number of records after dropping Null values: {len(total_bioact_df)}")
255
+
256
+ total_bioact_df.sort_values(by=["IC50s"],inplace=True)
257
+
258
+ limit = 50
259
+ if len(total_bioact_df) > limit:
260
+ total_bioact_df = total_bioact_df.iloc[:limit]
261
+
262
+ bioact_string = f'Results for top bioactivity (IC50 value) for molecules in ChEMBL ID: {chembl_id}. \n'
263
+ for smile, ic50 in zip(total_bioact_df['SMILES'], total_bioact_df['IC50s']):
264
+ bioact_string += f'Molecule SMILES: {smile}, IC50 (nM): {ic50}\n'
265
+
266
+ current_props_string += bioact_string
267
+ state["props_string"] = current_props_string
268
+ state["which_tool"] += 1
269
+ return state
270
+
271
+ def get_protein_from_pdb(pdb_id):
272
+ url = f"https://files.rcsb.org/download/{pdb_id}.pdb"
273
+ r = requests.get(url)
274
+ return r.text
275
+
276
+ def one_to_three(one_seq):
277
+ rev_aa_hash = {
278
+ 'A': 'ALA',
279
+ 'R': 'ARG',
280
+ 'N': 'ASN',
281
+ 'D': 'ASP',
282
+ 'C': 'CYS',
283
+ 'Q': 'GLN',
284
+ 'E': 'GLU',
285
+ 'G': 'GLY',
286
+ 'H': 'HIS',
287
+ 'I': 'ILE',
288
+ 'L': 'LEU',
289
+ 'K': 'LYS',
290
+ 'M': 'MET',
291
+ 'F': 'PHE',
292
+ 'P': 'PRO',
293
+ 'S': 'SER',
294
+ 'T': 'THR',
295
+ 'W': 'TRP',
296
+ 'Y': 'TYR',
297
+ 'V': 'VAL'
298
+ }
299
+
300
+ try:
301
+ three_seq = rev_aa_hash[one_seq]
302
+ except:
303
+ three_seq = 'X'
304
+
305
+ return three_seq
306
+
307
+ def three_to_one(three_seq):
308
+ aa_hash = {
309
+ 'ALA': 'A',
310
+ 'ARG': 'R',
311
+ 'ASN': 'N',
312
+ 'ASP': 'D',
313
+ 'CYS': 'C',
314
+ 'GLN': 'Q',
315
+ 'GLU': 'E',
316
+ 'GLY': 'G',
317
+ 'HIS': 'H',
318
+ 'ILE': 'I',
319
+ 'LEU': 'L',
320
+ 'LYS': 'K',
321
+ 'MET': 'M',
322
+ 'PHE': 'F',
323
+ 'PRO': 'P',
324
+ 'SER': 'S',
325
+ 'THR': 'T',
326
+ 'TRP': 'W',
327
+ 'TYR': 'Y',
328
+ 'VAL': 'V'
329
+ }
330
+
331
+ one_seq = []
332
+ for residue in three_seq:
333
+ try:
334
+ one_seq.append(aa_hash[residue])
335
+ except:
336
+ one_seq.append('X')
337
+
338
+ return one_seq
339
+
340
+ def pdb_node(state: State) -> State:
341
+ '''
342
+ Accepts a PDB ID and queires the protein databank for the sequence of the protein, as well as other
343
+ information such as ligands.
344
+
345
+ Args:
346
+ pdb: the PDB ID to query
347
+ Returns:
348
+ props_string: a string of the
349
+ '''
350
+ test_pdb = state["query_pdb"].strip()
351
+ current_props_string = state["props_string"]
352
+
353
+ print(f"pdb tool using {test_pdb}")
354
+ print('===================================================')
355
+
356
+ pdb_str = get_protein_from_pdb(test_pdb)
357
+ chains = {}
358
+ other_molecules = {}
359
+
360
+ #print(pdb_str.split('\n')[0])
361
+ for line in pdb_str.split('\n'):
362
+ parts = line.split()
363
+ try:
364
+ if parts[0] == 'SEQRES':
365
+ if parts[2] not in chains:
366
+ chains[parts[2]] = []
367
+ chains[parts[2]].extend(parts[4:])
368
+ if parts[0] == 'HETNAM':
369
+ j = 1
370
+ if parts[1].strip() in ['2','3','4','5','6','7','8','9']:
371
+ j = 2
372
+ print(parts[j])
373
+ if parts[j] not in other_molecules:
374
+ other_molecules[parts[j]] = []
375
+ other_molecules[parts[j]].extend(parts[2:])
376
+ except:
377
+ print('Blank line')
378
+
379
+ chains_ol = {}
380
+ for chain in chains:
381
+ chains_ol[chain] = three_to_one(chains[chain])
382
+
383
+ props_string = f"Chains in PDB ID {test_pdb}: {', '.join(chains.keys())} \n"
384
+ for chain in chains_ol:
385
+ props_string += f"Chain {chain}: {''.join(chains_ol[chain])} \n"
386
+ print(f"Chain {chain}: {''.join(chains_ol[chain])}")
387
+ props_string += f"Ligands in PDB ID {test_pdb}.\n"
388
+ for mol in other_molecules:
389
+ props_string += f"Molecule {mol}: {''.join(other_molecules[mol])} \n"
390
+
391
+ current_props_string += props_string
392
+ state["props_string"] = current_props_string
393
+ state["which_tool"] += 1
394
+ return state
395
+
396
+ def first_node(state: State) -> State:
397
+ '''
398
+ The first node of the agent. This node receives the input and asks the LLM
399
+ to determine which is the best tool to use to answer the QUERY TASK.
400
+
401
+ Input: the initial prompt from the user. should contain only one of more of the following:
402
+ query_protein: the name of the protein to search for.
403
+ query_up_id: the Uniprot ID of the protein to search for.
404
+ query_chembl: the chembl ID to query
405
+ query_pdb: the PDB ID to query
406
+ query_smiles: the smiles string
407
+ query_task: the query task
408
+ the value should be separated from the name by a ':' and each field should
409
+ be separated from the previous one by a ','.
410
+ All of these values are saved to the state
411
+
412
+ Output: the tool choice
413
+ '''
414
+ query_smiles = None
415
+ state["query_smiles"] = query_smiles
416
+ query_task = None
417
+ state["query_task"] = query_task
418
+ query_protein = None
419
+ state["query_protein"] = query_protein
420
+ query_up_id = None
421
+ state["query_up_id"] = query_up_id
422
+ query_pdb = None
423
+ state["query_pdb"] = query_pdb
424
+ query_chembl = None
425
+ state["query_chembl"] = query_chembl
426
+ props_string = ""
427
+ state["props_string"] = props_string
428
+ state["loop_again"] = None
429
+
430
+ raw_input = state["messages"][-1].content
431
+ parts = raw_input.split(',')
432
+ for part in parts:
433
+ if 'smiles' in part:
434
+ query_smiles = part.split(':')[1]
435
+ if query_smiles.lower() == 'none':
436
+ query_smiles = None
437
+ state["query_smiles"] = query_smiles
438
+ if 'task' in part:
439
+ query_task = part.split(':')[1]
440
+ state["query_task"] = query_task
441
+ if 'protein' in part:
442
+ query_protein = part.split(':')[1]
443
+ if query_protein.lower() == 'none':
444
+ query_protein = None
445
+ state["query_protein"] = query_protein
446
+ if 'up_id' in part:
447
+ query_up_id = part.split(':')[1]
448
+ if query_up_id.lower() == 'none':
449
+ query_up_id = None
450
+ state["query_up_id"] = query_up_id
451
+ if 'pdb' in part:
452
+ query_pdb = part.split(':')[1]
453
+ if query_pdb.lower() == 'none':
454
+ query_pdb = None
455
+ state["query_pdb"] = query_pdb
456
+ if 'chembl' in part:
457
+ query_chembl = part.split(':')[1]
458
+ if query_chembl.lower() == 'none':
459
+ query_chembl = None
460
+ state["query_chembl"] = query_chembl
461
+
462
+ prompt = f'For the QUERY_TASK given below, determine if one or two of the tools descibed below \
463
+ can complete the task. If so, reply with only the tool names followed by "#". If two tools \
464
+ are required, reply with both tool names separated by a comma and followed by "#". \
465
+ If the tools cannot complete the task, reply with "None #".\n \
466
+ QUERY_TASK: {query_task}.\n \
467
+ Tools: \n \
468
+ uniprot_tool: this tool takes in the user requested protein and searches UNIPROT for matches. \
469
+ It returns a string containing the protein ID, gene name, organism, and protein name.\n \
470
+ list_bioactives_tool: Accepts a given UNIPROT ID and searches for bioactive molecules \n \
471
+ get_bioactives_tool: Accepts a Chembl ID and get all bioactives molecule SMILES and IC50s for that ID\n \
472
+ pdb_tool: Accepts a PDB ID and queires the protein databank for the number of chains in and sequence of the \n \
473
+ protein, as well as other information such as ligands in the structure.\
474
+ '
475
+ res = chat_model.invoke(prompt)
476
+
477
+ tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
478
+ tool_choices = tool_choices.split(',')
479
+
480
+ if len(tool_choices) == 1:
481
+ tool1 = tool_choices[0].strip()
482
+ if tool1.lower() == 'none':
483
+ tool_choice = (None, None)
484
+ else:
485
+ tool_choice = (tool1, None)
486
+ elif len(tool_choices) == 2:
487
+ tool1 = tool_choices[0].strip()
488
+ tool2 = tool_choices[1].strip()
489
+ if tool1.lower() == 'none' and tool2.lower() == 'none':
490
+ tool_choice = (None, None)
491
+ elif tool1.lower() == 'none' and tool2.lower() != 'none':
492
+ tool_choice = (None, tool2)
493
+ elif tool2.lower() == 'none' and tool1.lower() != 'none':
494
+ tool_choice = (tool1, None)
495
+ else:
496
+ tool_choice = (tool1, tool2)
497
+ else:
498
+ tool_choice = (None, None)
499
+
500
+ state["tool_choice"] = tool_choice
501
+ state["which_tool"] = 0
502
+ print(f"The chosen tools are: {tool_choice}")
503
+
504
+ return state
505
+
506
+ def retry_node(state: State) -> State:
507
+ '''
508
+ If the previous loop of the agent does not get enough information from the
509
+ tools to answer the query, this node is called to retry the previous loop.
510
+ Input: the previous loop of the agent.
511
+ Output: the tool choice
512
+ '''
513
+ query_task = state["query_task"]
514
+ query_protein = state["query_protein"]
515
+ query_up_id = state["query_up_id"]
516
+ query_chembl = state["query_chembl"]
517
+ query_pdb = state["query_pdb"]
518
+ query_smiles = state["query_smiles"]
519
+
520
+ prompt = f'You were previously given the QUERY_TASK below, and asked to determine if one \
521
+ or two of the tools described below could complete the task. The tool choices did not succeed. \
522
+ Please re-examine the tool choices and determine if one or two of the tools described below \
523
+ can complete the task. If so, reply with only the tool names followed by "#". If two tools \
524
+ are required, reply with both tool names separated by a comma and followed by "#". \
525
+ If the tools cannot complete the task, reply with "None #".\n \
526
+ The information provided by the user is:\n \
527
+ QUERY_PROTEIN: {query_protein}.\n \
528
+ QUERY_UP_ID: {query_up_id}.\n \
529
+ QUERY_CHEMBL: {query_chembl}.\n \
530
+ QUERY_PDB: {query_pdb}.\n \
531
+ QUERY_SMILES: {query_smiles}.\n \
532
+ The task is: \
533
+ QUERY_TASK: {query_task}.\n \
534
+ Tool options: \n \
535
+ uniprot_tool: this tool takes in the user requested protein and searches UNIPROT for matches. \
536
+ It returns a string containing the protein ID, gene name, organism, and protein name.\n \
537
+ list_bioactives_tool: Accepts a given UNIPROT ID and searches for bioactive molecules \n \
538
+ get_bioactives_tool: Accepts a Chembl ID and get all bioactives molecule SMILES and IC50s for that ID\n \
539
+ pdb_tool: Accepts a PDB ID and queires the protein databank for the number of chains in and sequence of the \
540
+ protein, as well as other information such as ligands in the structure. \n'
541
+
542
+ res = chat_model.invoke(prompt)
543
+
544
+ tool_choices = str(res).split('<|assistant|>')[1].split('#')[0].strip()
545
+ tool_choices = tool_choices.split(',')
546
+
547
+ if len(tool_choices) == 1:
548
+ tool1 = tool_choices[0].strip()
549
+ if tool1.lower() == 'none':
550
+ tool_choice = (None, None)
551
+ else:
552
+ tool_choice = (tool1.strip(), None)
553
+ elif len(tool_choices) > 1:
554
+ tool1 = tool_choices[0].strip()
555
+ tool2 = tool_choices[1].strip()
556
+ if tool1.lower() == 'none' and tool2.lower() == 'none':
557
+ tool_choice = (None, None)
558
+ elif tool1.lower() == 'none' and tool2.lower() != 'none':
559
+ tool_choice = (None, tool2)
560
+ elif tool2.lower() == 'none' and tool1.lower() != 'none':
561
+ tool_choice = (tool1, None)
562
+ else:
563
+ tool_choice = (tool1, tool2)
564
+ else:
565
+ tool_choice = (None, None)
566
+
567
+ state["tool_choice"] = tool_choice
568
+ state["which_tool"] = 0
569
+ print(f"The chosen tools are (Retry): {tool_choice}")
570
+
571
+ return state
572
+
573
+ def loop_node(state: State) -> State:
574
+ '''
575
+ This node accepts the tool returns and decides if it needs to call another
576
+ tool or go on to the parser node.
577
+
578
+ Input: the tool returns.
579
+ Output: the next node to call.
580
+ '''
581
+ return state
582
+
583
+ def parser_node(state: State) -> State:
584
+ '''
585
+ This is the third node in the agent. It receives the output from the tool,
586
+ puts it into a prompt as CONTEXT, and asks the LLM to answer the original
587
+ query.
588
+
589
+ Input: the output from the tool.
590
+ Output: the answer to the original query.
591
+ '''
592
+ props_string = state["props_string"]
593
+ query_task = state["query_task"]
594
+ tool_choice = state["tool_choice"]
595
+
596
+ if type(tool_choice) != tuple and tool_choice == None:
597
+ state["loop_again"] = "finish_gracefully"
598
+ return state
599
+ elif type(tool_choice) == tuple and (tool_choice[0] == None) and (tool_choice[1] == None):
600
+ state["loop_again"] = "finish_gracefully"
601
+ return state
602
+
603
+ prompt = f'Using the CONTEXT below, answer the original query, which \
604
+ was to answer the QUERY_TASK. End your answer with a "#" \
605
+ QUERY_TASK: {query_task}.\n \
606
+ CONTEXT: {props_string}.\n '
607
+
608
+ res = chat_model.invoke(prompt)
609
+ trial_answer = str(res).split('<|assistant|>')[1]
610
+ print('parser 1 ', trial_answer)
611
+ state["messages"] = res
612
+
613
+ check_prompt = f'Determine if the TRIAL ANSWER below answers the original \
614
+ QUERY TASK. If it does, respond with "PROCEED #" . If the TRIAL ANSWER did not \
615
+ answer the QUERY TASK, respond with "LOOP #" \n \
616
+ Only loop again if the TRIAL ANSWER did not answer the QUERY TASK. \
617
+ TRIAL ANSWER: {trial_answer}.\n \
618
+ QUERY_TASK: {query_task}.\n'
619
+
620
+ res = chat_model.invoke(check_prompt)
621
+ print('parser, loop again? ', res)
622
+
623
+ if str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "loop":
624
+ state["loop_again"] = "loop_again"
625
+ return state
626
+ elif str(res).split('<|assistant|>')[1].split('#')[0].strip().lower() == "proceed":
627
+ state["loop_again"] = None
628
+ print('trying to break loop')
629
+ elif "proceed" in str(res).split('<|assistant|>')[1].lower():
630
+ state["loop_again"] = None
631
+ print('trying to break loop')
632
+
633
+ return state
634
+
635
+ def reflect_node(state: State) -> State:
636
+ '''
637
+ This is the fourth node of the agent. It recieves the LLMs previous answer and
638
+ tries to improve it.
639
+
640
+ Input: the LLMs last answer.
641
+ Output: the improved answer.
642
+ '''
643
+ previous_answer = state["messages"][-1].content
644
+ props_string = state["props_string"]
645
+
646
+ prompt = f'Look at the PREVIOUS ANSWER below which you provided and the \
647
+ TOOL RESULTS. Write an improved answer based on the PREVIOUS ANSWER and the \
648
+ TOOL RESULTS by adding additional clarifying and enriching information. End \
649
+ your new answer with a "#" \
650
+ PREVIOUS ANSWER: {previous_answer}.\n \
651
+ TOOL RESULTS: {props_string}. '
652
+
653
+ res = chat_model.invoke(prompt)
654
+ return {"messages": res}
655
+
656
+ def graceful_exit_node(state: State) -> State:
657
+ '''
658
+ Called when the Agent cannot assign any tools for the task
659
+ '''
660
+ props_string = state["props_string"]
661
+ prompt = f'Summarize the information in the CONTEXT, including any useful chemical information. Start your answer with: \
662
+ Here is what I found: \n \
663
+ CONTEXT: {props_string}'
664
+
665
+ res = chat_model.invoke(prompt)
666
+
667
+ return {"messages": res}
668
+
669
+ def get_chemtool(state):
670
+ '''
671
+ '''
672
+ which_tool = state["which_tool"]
673
+ tool_choice = state["tool_choice"]
674
+
675
+ if tool_choice is None or tool_choice == (None, None):
676
+ return None
677
+
678
+ if which_tool == 0 or which_tool == 1:
679
+ current_tool = tool_choice[which_tool]
680
+ if current_tool is None:
681
+ return None
682
+ elif which_tool > 1:
683
+ current_tool = None
684
+
685
+ return current_tool
686
+
687
+ def loop_or_not(state):
688
+ '''
689
+ '''
690
+ print(f"(line 482) Loop? {state['loop_again']}")
691
+ if state["loop_again"] == "loop_again":
692
+ return True
693
+ elif state["loop_again"] == "finish_gracefully":
694
+ return 'lets_get_outta_here'
695
+ else:
696
+ return False
697
+
698
+ def pretty_print(answer):
699
+ final = str(answer['messages'][-1]).split('<|assistant|>')[-1].split('#')[0].strip("n").strip('\\').strip('n').strip('\\')
700
+ for i in range(0,len(final),100):
701
+ print(final[i:i+100])
702
+
703
+ def print_short(answer):
704
+ for i in range(0,len(answer),100):
705
+ print(answer[i:i+100])
706
+
707
+ builder = StateGraph(State)
708
+ builder.add_node("first_node", first_node)
709
+ builder.add_node("retry_node", retry_node)
710
+ builder.add_node("uniprot_node", uniprot_node)
711
+ builder.add_node("listbioactives_node", listbioactives_node)
712
+ builder.add_node("getbioactives_node", getbioactives_node)
713
+ builder.add_node("pdb_node", pdb_node)
714
+
715
+ builder.add_node("loop_node", loop_node)
716
+ builder.add_node("parser_node", parser_node)
717
+ builder.add_node("reflect_node", reflect_node)
718
+ builder.add_node("graceful_exit_node", graceful_exit_node)
719
+
720
+ builder.add_edge(START, "first_node")
721
+ builder.add_conditional_edges("first_node", get_chemtool, {
722
+ "uniprot_tool": "uniprot_node",
723
+ "list_bioactives_tool": "listbioactives_node",
724
+ "get_bioactives_tool": "getbioactives_node",
725
+ "pdb_tool": "pdb_node",
726
+ None: "parser_node"})
727
+
728
+ builder.add_conditional_edges("retry_node", get_chemtool, {
729
+ "uniprot_tool": "uniprot_node",
730
+ "list_bioactives_tool": "listbioactives_node",
731
+ "get_bioactives_tool": "getbioactives_node",
732
+ "pdb_tool": "pdb_node",
733
+ None: "parser_node"})
734
+
735
+ builder.add_edge("uniprot_node", "loop_node")
736
+ builder.add_edge("listbioactives_node", "loop_node")
737
+ builder.add_edge("getbioactives_node", "loop_node")
738
+ builder.add_edge("pdb_node", "loop_node")
739
+
740
+ builder.add_conditional_edges("loop_node", get_chemtool, {
741
+ "uniprot_tool": "uniprot_node",
742
+ "list_bioactives_tool": "listbioactives_node",
743
+ "get_bioactives_tool": "getbioactives_node",
744
+ "pdb_tool": "pdb_node",
745
+ None: "parser_node"})
746
+
747
+ builder.add_conditional_edges("parser_node", loop_or_not, {
748
+ True: "retry_node",
749
+ 'lets_get_outta_here': "graceful_exit_node",
750
+ False: "reflect_node"})
751
+
752
+ builder.add_edge("reflect_node", END)
753
+ builder.add_edge("graceful_exit_node", END)
754
+
755
+ graph = builder.compile()
756
+
757
+ @spaces.GPU
758
+ def ProteinAgent(task, protein, up_id, chembl_id, pdb_id, smiles):
759
+ input = {
760
+ "messages": [
761
+ HumanMessage(f'query_task: {task}, query_protein: {protein}, query_up_id: {up_id}, query_chembl: {chembl_id}, query_pdb: {pdb_id}, query_smiles: {smiles}')
762
+ ]
763
+ }
764
+
765
+ #if Substitution_image.png exists, remove it
766
+ if os.path.exists('Substitution_image.png'):
767
+ os.remove('Substitution_image.png')
768
+
769
+ #print(input)
770
+ replies = []
771
+ for c in graph.stream(input): #, stream_mode='updates'):
772
+ m = re.findall(r'[a-z]+\_node', str(c))
773
+ if len(m) != 0:
774
+ reply = c[str(m[0])]['messages']
775
+ if 'assistant' in str(reply):
776
+ reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
777
+ replies.append(reply)
778
+ #check if image exists
779
+ if os.path.exists('Substitution_image.png'):
780
+ img_loc = 'Substitution_image.png'
781
+ img = Image.open(img_loc)
782
+ #else create a dummy blank image
783
+ else:
784
+ img = Image.new('RGB', (250, 250), color = (255, 255, 255))
785
+
786
+ return replies[-1], img
787
+
788
+ with gr.Blocks(fill_height=True) as forest:
789
+ gr.Markdown('''
790
+ # Protein Agent
791
+ - calls Uniprot to find uniprot ids
792
+ - calls Chembl to find hits for a given uniprot id and reports number of bioactive molecules in the hit
793
+ - calls Chembl to find a list bioactive molecules for a given chembl id and their IC50 values
794
+ - calls PDB to find the number of chains in a protein, proteins sequences and small molecules in the structure
795
+ ''')
796
+
797
+ with gr.Row():
798
+ with gr.Column():
799
+ protein = gr.Textbox(label="Protein name of interest (optional): ", placeholder='none')
800
+ up_id = gr.Textbox(label="Uniprot ID of interest (optional): ", placeholder='none')
801
+ chembl_id = gr.Textbox(label="Chembl ID of interest (optional): ", placeholder='none')
802
+ pdb_id = gr.Textbox(label="PDB ID of interest (optional): ", placeholder='none')
803
+ smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
804
+ task = gr.Textbox(label="Task for Agent: ")
805
+ calc_btn = gr.Button(value = "Submit to Agent")
806
+ with gr.Column():
807
+ props = gr.Textbox(label="Agent results: ", lines=20 )
808
+ pic = gr.Image(label="Molecule")
809
+
810
+
811
+ calc_btn.click(ProteinAgent, inputs = [task, protein, up_id, chembl_id, pdb_id, smiles], outputs = [props, pic])
812
+ task.submit(ProteinAgent, inputs = [task, protein, up_id, chembl_id, pdb_id, smiles], outputs = [props, pic])
813
+
814
+ forest.launch(debug=False, mcp_server=True)