|
|
from input_parsing import start_ner, start_embedding, intake, define_tool_hash, tool_descriptions_values, second_intake, parse_input
|
|
|
from input_parsing import define_tool_reqs, smiles_regex
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
|
import torch
|
|
|
from PIL import Image
|
|
|
import io, json, pprint as pp
|
|
|
from scholarly import scholarly, ProxyGenerator
|
|
|
import numpy as np
|
|
|
import ast
|
|
|
|
|
|
|
|
|
import torch
|
|
|
import gradio as gr
|
|
|
|
|
|
|
|
|
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
def start_models():
    '''
    Load the support models used for parsing queries and ranking tools.

    Calls start_ner() for the NER model and start_embedding() with the tool
    descriptions for the embedding model and the encoded descriptions.

    Returns:
        parse_model: The NER model.
        document_embeddings: The encoded document embeddings.
        embed_model: The embedding model.
    '''
    ner_model = start_ner()
    encoded_docs, embedder = start_embedding(tool_descriptions_values)
    return ner_model, encoded_docs, embedder
|
|
|
|
|
|
class chat_manager():
    '''
    Turn-based chat controller for the drug design agent.

    Each call to chat() handles one user turn. ``chat_idx`` records which step
    of the tool workflow the next turn belongs to:

        0   - new query: choose candidate tools and parse entities
        1   - tool choice: run the chosen tool (or start editing a list)
        2   - follow-up query after a tool has been run in AI mode
        999 - the chosen tool lacked data; parse the user's reply for it
        501 - list editing: the user picks which list to edit
        502 - list editing: the user picks which items to keep

    Every turn is appended to ``chat_history`` as ``[query, response]`` and the
    whole history is rewritten to HISTORY_FILE.
    '''

    # File the chat history is persisted to (read in __init__, rewritten per turn).
    HISTORY_FILE = 'chat_session_history.txt'

    # Menu label -> attribute holding that list; dict order is the menu order.
    _LIST_ATTRS = {
        'smiles list': 'smiles_list',
        'names list': 'names_list',
        'proteins list': 'proteins_list',
        'uniprot list': 'uniprot_list',
        'pdb list': 'pdb_list',
        'chembl list': 'chembl_list',
    }

    _MODEL_TURN = '<start_of_turn>model'
    _END_TURN = '<end_of_turn>'

    def __init__(self, model_id: str = 'google/gemma-3-1b-it', device: str = 'cuda'):
        '''
        Store the configuration, initialise the workflow state and load any
        previously saved chat history.

        Args:
            model_id: Hugging Face model id of the chat LLM.
            device: Device label stored on the instance.
        '''
        self.model_id = model_id
        self.device = device
        self.reset_chat()
        self.latest_response = ''
        self.results_string = ''
        self.results_images = None

        # A missing or unparsable history file just means "start fresh".
        try:
            with open(self.HISTORY_FILE, 'r') as f:
                loaded = ast.literal_eval(f.read())
        except (OSError, ValueError, SyntaxError, TypeError):
            loaded = []
        self.chat_history = loaded if isinstance(loaded, list) else []

    def start_model_tokenizer(self):
        '''
        Downloads and loads the model and tokenizer.

        Args:
            None
        Returns:
            None
        Also defines:
            llm_model: The 8-bit quantized model, in eval mode.
            tokenizer: The tokenizer to use.
        '''
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.llm_model = AutoModelForCausalLM.from_pretrained(
            self.model_id, quantization_config=quantization_config).eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

        # Report where the model actually lives; self.device is never used to place it.
        print(f"Model loaded on {self.llm_model.device}")

    def start_support_models(self):
        '''
        Starts the supporting models for parsing and embedding.
        '''
        self.parse_model, self.document_embeddings, self.embed_model = start_models()

    def reset_chat(self):
        '''
        Resets the workflow state (step index, tools, entity lists, query) but
        keeps the chat history.
        '''
        self.chat_idx = 0
        self.best_tools = []
        self.proteins_list = []
        self.names_list = []
        self.smiles_list = []
        self.uniprot_list = []
        self.pdb_list = []
        self.chembl_list = []
        self.query = ''
        self.present = []

    def hard_reset_chat(self):
        '''
        Resets the workflow state and also clears the in-memory chat history.
        '''
        self.reset_chat()
        self.chat_history = []

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _save_history(self):
        '''Write the whole chat history to HISTORY_FILE as a Python literal.'''
        with open(self.HISTORY_FILE, 'w') as f:
            pp.pp(self.chat_history, stream=f)

    def _reply(self, query, response, image=None):
        '''Record one [query, response] turn, persist it and build chat()'s return value.'''
        self.chat_history.append([query, response])
        self._save_history()
        return '', self.chat_history, image

    def _entity_lists(self):
        '''Entity lists in the positional order used when pairing them with ``present``.'''
        return [self.proteins_list, self.names_list, self.smiles_list,
                self.uniprot_list, self.pdb_list, self.chembl_list]

    def _refresh_present(self):
        '''Recompute the per-entity counts from the current entity lists.'''
        self.present = {
            'proteins': len(self.proteins_list),
            'molecules': len(self.names_list),
            'smiles': len(self.smiles_list),
            'uniprot': len(self.uniprot_list),
            'pdb': len(self.pdb_list),
            'chembl': len(self.chembl_list),
        }

    def _history_text(self):
        '''All recorded turns, each preceded by a newline.'''
        return ''.join(f'\n{turn}' for chat in self.chat_history for turn in chat)

    @staticmethod
    def _fence_smiles(text):
        '''Wrap every SMILES match in triple backticks so markdown does not mangle it.'''
        # De-duplicate so a repeated match is not fenced twice.
        for match in dict.fromkeys(smiles_regex(text)):
            if match in text:
                text = text.replace(match, f'```{match}```')
        return text

    @staticmethod
    def _extract_model_reply(decoded):
        '''
        Cut the model's answer out of a decoded generation.

        The answer is the text after the first model-turn marker, up to the
        next end-of-turn marker. str.strip('<end_of_turn>') must not be used
        here: it strips a *set of characters* and would eat leading/trailing
        letters of the answer.
        '''
        _, found, reply = decoded.partition(chat_manager._MODEL_TURN)
        if not found:
            reply = decoded
        return reply.split(chat_manager._END_TURN)[0].strip()

    def _generate(self, role_text, prompt):
        '''Run the LLM on a system + user message pair and return the cleaned reply.'''
        messages = [[
            {"role": "system", "content": [{"type": "text", "text": role_text}]},
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.llm_model.device)

        with torch.inference_mode():
            outputs = self.llm_model.generate(
                **inputs, max_new_tokens=500, do_sample=True, temperature=0.5)

        decoded = self.tokenizer.batch_decode(outputs)
        return self._extract_model_reply(decoded[0])

    def _tool_summary(self, header=''):
        '''Markdown listing the candidate tools, the parsed entities and the next-step options.'''
        parts = [header, '## The tools chosen based on your query are:']
        for i, tool in enumerate(self.best_tools):
            description = full_tool_descriptions.get(tool, 'No description available.')
            parts.append('\n' + f'{i+1}. {tool} : {description}')

        parts.append(' \n\n ## And the following information was found in your query:\n')
        for entity_type, entity_list in zip(self.present, self._entity_lists()):
            if self.present[entity_type] > 0:
                parts.append(f'**{entity_type}**: {self.present[entity_type]}\n')
                for ent_idx, entity in enumerate(entity_list):
                    parts.append(f'- **{ent_idx+1}. {entity_type}**: {entity}\n')

        parts.append('\n To accept the #1 tool choice, hit enter; to choose 2 or 3, enter that number.')
        parts.append('\n To edit the items in a list, enter "edit".')
        parts.append('\n To start over, click the clear button and enter a new query.')
        return self._fence_smiles(''.join(parts))

    def _first_result_image(self):
        '''Return the first tool result image as a PIL image, or None if there is none.'''
        # Deliberately best-effort: results_images may be None, empty, or not image data.
        try:
            return Image.open(io.BytesIO(self.results_images[0].data))
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------

    def chat(self, query: str, mode_flag: str = 'AI'):
        '''
        Chats with the model.

        Args:
            query: The prompt to send to the model.
            mode_flag: The mode to use (AI, Manual, Web Search, Chat). In the
                follow-up step, 'Review History' uses the full history as context.
        Returns:
            A tuple ('', chat_history, image): an empty string, the full chat
            history, and a PIL image from the tool result or None.
        '''
        if mode_flag == 'Web Search':
            return self._web_search_turn(query)
        if mode_flag == 'Chat':
            return self._llm_chat_turn(query)

        steps = {
            0: self._new_query_turn,
            1: self._tool_choice_turn,
            2: self._follow_up_turn,
            999: self._missing_data_turn,
            501: self._pick_list_turn,
            502: self._edit_list_turn,
        }
        step = steps.get(self.chat_idx)
        if step is None:
            # Unknown state: start over rather than returning nothing to the caller.
            self.reset_chat()
            step = self._new_query_turn
        return step(query, mode_flag)

    # ------------------------------------------------------------------
    # One handler per mode / workflow step
    # ------------------------------------------------------------------

    def _web_search_turn(self, query):
        '''Web Search mode: send the query to websearch_node and return its result text.'''
        self.chat_idx = 0
        _, search_string, _ = websearch_node(query, self.embed_model)
        return self._reply(query, search_string)

    def _llm_chat_turn(self, query):
        '''Chat mode: answer with the LLM, using the previous history as context.'''
        self.chat_idx = 0
        self.query = query

        context = 'Previous chat history: ' + self._history_text()

        role_text = (
            "You are part of a drug design agent. Answer user questions to the best of your ability. "
            "If the user asks any innapropriate questions, respond with 'I'm sorry, I can't assist with that request.' "
            "If the user asks for general information, provide a concise and accurate answer. If the user asks about drug design, "
            "provide detailed and informative answers or refer them to the tools. They can access the tools by switching to AI or "
            "manual mode. Reference the previous conversation in the context if needed."
        )
        prompt = f'Query: {self.query}; CONTEXT: {context}.'

        response = self._generate(role_text, prompt)
        self.latest_response = response
        return self._reply(query, response)

    def _new_query_turn(self, query, mode_flag):
        '''Step 0: send the query to intake() for the best tools and parsed entities.'''
        self.query = query
        (self.best_tools, self.present, self.proteins_list, self.names_list,
         self.smiles_list, self.uniprot_list, self.pdb_list, self.chembl_list) = \
            intake(self.query, self.parse_model, self.embed_model, self.document_embeddings)

        response = self._tool_summary()
        self.chat_idx = 1
        return self._reply(query, response)

    def _tool_choice_turn(self, query, mode_flag):
        '''
        Step 1: read the tool choice, check the tool's required data, run the
        tool and return its results directly (Manual) or via the LLM (AI).
        '''
        if 'edit' in query.lower():
            # Switch to the list-editing sub-flow.
            self.chat_idx = 501
            response = '## Enter the list to edit:\n'
            for i, list_name in enumerate(self._LIST_ATTRS):
                response += f'**{i+1}**. {list_name}\n'
            return self._reply(query, response)

        choice = query.strip()
        tool_choice = None
        if choice == '':
            tool_choice = 0
        elif choice in ('1', '2', '3'):
            tool_choice = int(choice) - 1

        # Also reject a number for which no tool was offered.
        if tool_choice is None or tool_choice >= len(self.best_tools):
            response = ('Invalid input. Please enter 1, 2, or 3 to choose one of the tools above, '
                        'or hit enter to accept the #1 tool choice. '
                        'or enter "edit" to edit a list.')
            self.chat_idx = 1
            return self._reply(query, response)

        self.tool_choice = tool_choice
        tool = self.best_tools[self.tool_choice]

        # Check that the data the tool needs is present; ask the user for it if not.
        tool_function_reqs = define_tool_reqs(tool, *self._entity_lists())
        reqs_list = tool_function_reqs[tool][0]
        list_names = tool_function_reqs[tool][1]
        missing = [name for sub_list, name in zip(reqs_list, list_names) if len(sub_list) == 0]
        if missing:
            data_request = f'The necessary data was not found for tool {tool}.\n'
            for list_name in missing:
                data_request += f'Missing information for: *{list_name}*.\n'
            data_request += 'Please provide the necessary information to proceed.'
            self.chat_idx = 999
            return self._reply(query, data_request)

        # Look up the chosen tool function and its arguments, then call it.
        tool_function_hash = define_tool_hash(tool, *self._entity_lists())
        tool_function, args_list = tool_function_hash[tool][0], tool_function_hash[tool][1]
        _results_list, self.results_string, self.results_images = tool_function(*args_list)
        print(self.results_string)

        if mode_flag == 'Manual':
            # Manual mode returns the raw tool output and ends the workflow.
            self.results_string = self._fence_smiles(self.results_string)
            img = self._first_result_image()
            self.reset_chat()
            return self._reply(query, self.results_string, img)

        # AI mode: let the LLM phrase the answer from the tool output.
        role_text = ("Answer the query using the information in the context. Add explanations "
                     "or enriching information where appropriate.")
        prompt = f'Query: {self.query}.\n Context: {self.results_string}'

        response = self._generate(role_text, prompt)
        self.latest_response = response
        response = self._fence_smiles(response)

        self.chat_idx = 2
        return self._reply(query, response, self._first_result_image())

    def _follow_up_turn(self, query, mode_flag):
        '''
        Step 2: re-run tool selection for a follow-up query via second_intake(),
        then go back to the tool-choice step.
        '''
        self.query = query

        if mode_flag == 'Review History':
            context = self.query + self._history_text()
        else:
            context = self.latest_response + '\n' + self.results_string + '\n' + self.query

        (self.best_tools, self.present, self.proteins_list, self.names_list,
         self.smiles_list, self.uniprot_list, self.pdb_list, self.chembl_list) = \
            second_intake(self.query, context, self.parse_model, self.embed_model,
                          self.document_embeddings)

        response = self._tool_summary(header=f'## Your new query is: {self.query}\n')
        self.chat_idx = 1
        return self._reply(query, response)

    def _missing_data_turn(self, query, mode_flag):
        '''
        Step 999: parse the user's reply for the missing data, fill only the
        lists that are still empty, then return to the tool-choice step.
        '''
        _present, proteins_list, names_list, smiles_list, uniprot_list, pdb_list, chembl_list = \
            parse_input(query, self.parse_model)

        parsed = {
            'proteins_list': proteins_list,
            'names_list': names_list,
            'smiles_list': smiles_list,
            'uniprot_list': uniprot_list,
            'pdb_list': pdb_list,
            'chembl_list': chembl_list,
        }
        for attr, new_values in parsed.items():
            if len(getattr(self, attr)) == 0 and len(new_values) > 0:
                setattr(self, attr, new_values)

        # Recount from the lists: adding the parsed counts would overstate
        # entities whose (non-empty) list was deliberately left untouched.
        self._refresh_present()

        response = self._tool_summary(header=f'## Your new query is: {self.query}\n')
        self.chat_idx = 1
        return self._reply(query, response)

    def _pick_list_turn(self, query, mode_flag):
        '''Step 501: read which list the user wants to edit.'''
        menu = list(self._LIST_ATTRS)
        try:
            choice = int(query)
        except ValueError:
            choice = 0

        # Explicit range check: 0 or a negative number must not wrap to the end of the menu.
        if 1 <= choice <= len(menu):
            self.list_to_edit = menu[choice - 1]
            response = f'## You have chosen to edit the {self.list_to_edit}.\n'
            response += 'Enter the numbers for the *items to keep* in the list.'
            self.chat_idx = 502
        else:
            response = 'Invalid input. Please enter the number corresponding to the list you wish to edit.'
            self.chat_idx = 501

        return self._reply(query, response)

    def _edit_list_turn(self, query, mode_flag):
        '''Step 502: keep only the items the user selected in the chosen list.'''
        if ',' in query:
            raw_items = query.split(',')
        elif ';' in query:
            raw_items = query.split(';')
        else:
            raw_items = query.split()
        # Tolerate spaces around separators and a trailing separator ("1, 2,").
        items_to_keep = [item.strip() for item in raw_items if item.strip()]

        self.chat_idx = 502
        if not all(item.isdecimal() for item in items_to_keep):
            response = 'Invalid input. Please enter the numbers corresponding to the items you wish to keep.'
            return self._reply(query, response)

        attr = self._LIST_ATTRS.get(getattr(self, 'list_to_edit', None))
        current_list = getattr(self, attr, None) if attr else None
        positions = [int(item) for item in items_to_keep]
        # Reject out-of-range numbers up front so 0 cannot wrap to the last item
        # and the list is never left half-edited.
        if current_list is None or any(not 1 <= pos <= len(current_list) for pos in positions):
            response = 'An error occurred while processing your input. Please try again.'
            return self._reply(query, response)

        setattr(self, attr, [current_list[pos - 1] for pos in positions])
        self._refresh_present()

        response = self._tool_summary()
        self.chat_idx = 1
        return self._reply(query, response)
|
|
|
|
|
|
|
|
|
|
|
|
# Human-readable description of every tool, keyed by tool name; shown to the
# user when the candidate tools are listed. Long descriptions use implicit
# string concatenation (not backslash continuations inside the literal) so no
# source indentation leaks into the text and no words are fused together.
full_tool_descriptions = {
    'smiles_node': 'Queries Pubchem for the smiles string of the molecule based on the name.',
    'name_node': 'Queries Pubchem for the name of the molecule based on the smiles string.',
    'related_node': 'Queries Pubchem for similar molecules based on the smiles string or name.',
    'substitution_node': (
        'A simple substitution routine that looks for a substituent on a phenyl ring and '
        'substitutes different fragments in that location. Returns a list of novel molecules and their '
        'QED score (1 is most drug-like, 0 is least drug-like).'
    ),
    'lipinski_node': 'A tool to calculate QED and other lipinski properties of a molecule.',
    'pharmfeature_node': (
        'A tool to compare the pharmacophore features of a query molecule against '
        'those of a reference molecule and report the pharmacophore features of both and the feature '
        'score of the query molecule.'
    ),
    'uniprot_node': (
        'This tool takes in the user requested protein and searches UNIPROT for matches. '
        'It returns a string scontaining the protein ID, gene name, organism, and protein name.'
    ),
    'listbioactives_node': (
        'Accepts a UNIPROT ID and searches for bioactive molecules. Returns counts of '
        'the bioactives found and their ChEMBL IDs.'
    ),
    'getbioactives_node': 'Accepts a Chembl ID and get all bioactives molecule SMILES and IC50s for that ID.',
    'predict_node': (
        'uses the current_bioactives.csv file from the get_bioactives node to fit the '
        'Light GBM model and predict the IC50 for the current smiles.'
    ),
    'gpt_node': (
        'Uses a Chembl dataset, previously stored in a CSV file by the get_bioactives node, to '
        'finetune a GPT model to generate novel molecules for the target protein.'
    ),
    'pdb_node': (
        'Accepts a PDB ID and queires the protein databank for the sequence of the protein, '
        'as well as other information such as ligands.'
    ),
    'find_node': (
        'Accepts a protein name and searches the protein databack for PDB IDs that match along '
        'with the entry titles.'
    ),
    'docking_node': (
        'Docking tool: uses dockstring to dock the molecule into the protein binding site and returns '
        'the docking score and the binding pose.'
    ),
    'get_actives_for_protein': (
        'Finds Bioactive molecules for a give protein. Uses Uniprot to find chembl IDs '
        'for the protein, and then queries chembl for bioactive molecules.'
    ),
    'get_predictions_for_protein': (
        'Uses Uniprot to find chembl IDs for the protein, and then queries chembl '
        'for bioactive molecules to train a model and predict the activity of the given smiles.'
    ),
    'dock_from_names': 'Accepts names of molecules and docks them in a given protein.',
}
|
|
|
|
|
|
def websearch_node(query: str, embed_model, proxy_flag: bool = True) -> tuple[list, str, None]:
    '''
    Performs a web search using scholarly and ranks results based on similarity to the query.

    Up to 10 publications are fetched; their abstracts are embedded with
    embed_model and compared with the embedded query, and the hits are
    returned in descending order of similarity.

    Args:
        query (str): The input query string.
        embed_model: The embedding model (must provide encode_document,
            encode_query and similarity).
        proxy_flag (bool): If True, try to route scholarly through a free proxy.
    Returns:
        top_hits (list): List of (title, link, score) tuples, best match first.
        search_string (str): String representation of the top hits, or a
            message if the search failed or found nothing.
        None: Placeholder for images (not used here).
    '''
    from itertools import islice

    max_results = 10
    top_hits = []

    # Best-effort boundary: any failure (network, proxy, parsing) yields a
    # "failed" message instead of an exception, but the cause is printed.
    try:
        if proxy_flag:
            pg = ProxyGenerator()
            # FreeProxies() reports whether a proxy was set up; one call is enough.
            if pg.FreeProxies():
                scholarly.use_proxy(pg)

        scholarly.set_timeout(15)

        search_query = scholarly.search_pubs(query)
        print(f'Search generator created for query: {query}')

        titles = []
        links = []
        abstracts = []

        # islice stops cleanly when the search yields fewer than max_results items.
        for i, item in enumerate(islice(search_query, max_results)):
            bib = item.get('bib') or {}
            links.append(item.get('pub_url', ''))
            titles.append(bib.get('title', ''))
            abstracts.append(bib.get('abstract', ''))
            print(f'Found result {i+1}')

        print(f'Found {len(titles)} results')

        if not titles:
            return top_hits, 'No results were found for your query.', None

        abstract_embeddings = embed_model.encode_document(abstracts)
        query_embeddings = embed_model.encode_query(query)
        scores = embed_model.similarity(query_embeddings, abstract_embeddings)

        # scores[0] holds one similarity per abstract; convert to plain floats
        # so ranking does not depend on the array type or the device it is on.
        score_row = [float(score) for score in scores[0]]
        ranking = sorted(range(len(score_row)), key=score_row.__getitem__, reverse=True)
        top_hits = [(titles[idx], links[idx], score_row[idx]) for idx in ranking[:max_results]]

        search_string = f'The top {len(top_hits)} hits for your query are:\n'
        for rank, (title, link, score) in enumerate(top_hits, start=1):
            search_string += f'{rank}. {title}\nLink: {link}\nScore: {score:.3f}\n\n'

        print('Web search completed successfully.')

    except Exception as exc:
        top_hits = []
        search_string = 'Web search failed. Please try again later.'
        print(f'Web search failed due to an exception: {exc!r}')

    return top_hits, search_string, None
|
|
|
|
|
|
'''================================================================================================
|
|
|
Functions for use on Hugging Face Spaces
|
|
|
===================================================================================================='''
|
|
|
|
|
|
|
|
|
''' ======================================================================================================
|
|
|
older functions retained for compatibility
|
|
|
======================================================================================================'''
|
|
|
|
|
|
|
|
|
def query_to_context(query: str, parse_model, embed_model, document_embeddings):
    '''
    Processes a query to extract relevant context and information.

    Runs intake() on the query, then calls the best-ranked tool. If that tool
    returns an empty results list, the next-ranked tools are tried in turn (at
    most three tools in total). The results of the last tool tried are
    returned, even if they are empty.

    Args:
        query (str): The input query string.
        parse_model: The NER model.
        embed_model: The embedding model.
        document_embeddings: The encoded document embeddings.
    Returns:
        results_list: List of results.
        results_string: String representation of results.
        results_images: Any associated images with the results.
    Raises:
        IndexError: If intake() returns no candidate tools.
    '''
    max_tools = 3

    best_tools, present, proteins_list, names_list, smiles_list, uniprot_list, pdb_list, chembl_list = \
        intake(query, parse_model, embed_model, document_embeddings)

    if len(best_tools) == 0:
        raise IndexError('intake() returned no candidate tools for the query')

    entity_lists = (proteins_list, names_list, smiles_list, uniprot_list, pdb_list, chembl_list)

    # Iterate over the tools actually returned (never index past the end of
    # best_tools) and stop at the first tool that produces results.
    results_tuple = None
    for tool in best_tools[:max_tools]:
        tool_function_hash = define_tool_hash(tool, *entity_lists)
        args_list = tool_function_hash[tool][1]
        results_tuple = tool_function_hash[tool][0](*args_list)
        if results_tuple[0] != []:
            break

    results_list, results_string, results_images = results_tuple

    return results_list, results_string, results_images
|
|
|
|
|
|
|