Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,6 +27,9 @@ import pubchempy as pcp
|
|
| 27 |
import gradio as gr
|
| 28 |
from PIL import Image
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
|
| 32 |
hf = HuggingFacePipeline.from_model_id(
|
|
@@ -36,3 +39,97 @@ hf = HuggingFacePipeline.from_model_id(
|
|
| 36 |
|
| 37 |
chat_model = ChatHuggingFace(llm=hf)
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
import gradio as gr
|
| 28 |
from PIL import Image
|
| 29 |
|
| 30 |
+
from chem_nodes import *
|
| 31 |
+
from agent_nodes import *
|
| 32 |
+
|
| 33 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 34 |
|
| 35 |
hf = HuggingFacePipeline.from_model_id(
|
|
|
|
| 39 |
|
| 40 |
chat_model = ChatHuggingFace(llm=hf)
|
| 41 |
|
| 42 |
+
builder = StateGraph(State)
|
| 43 |
+
builder.add_node("first_node", first_node)
|
| 44 |
+
builder.add_node("retry_node", retry_node)
|
| 45 |
+
builder.add_node("smiles_node", smiles_node)
|
| 46 |
+
builder.add_node("name_node", name_node)
|
| 47 |
+
builder.add_node("similars_node", similars_node)
|
| 48 |
+
builder.add_node("loop_node", loop_node)
|
| 49 |
+
builder.add_node("parser_node", parser_node)
|
| 50 |
+
builder.add_node("reflect_node", reflect_node)
|
| 51 |
+
|
| 52 |
+
builder.add_edge(START, "first_node")
|
| 53 |
+
builder.add_conditional_edges("first_node", get_chemtool, {
|
| 54 |
+
"smiles_tool": "smiles_node",
|
| 55 |
+
"name_tool": "name_node",
|
| 56 |
+
"similars_tool": "similars_node",
|
| 57 |
+
None: "parser_node"})
|
| 58 |
+
|
| 59 |
+
builder.add_conditional_edges("retry_node", get_chemtool, {
|
| 60 |
+
"smiles_tool": "smiles_node",
|
| 61 |
+
"name_tool": "name_node",
|
| 62 |
+
"similars_tool": "similars_node",
|
| 63 |
+
None: "parser_node"})
|
| 64 |
+
|
| 65 |
+
builder.add_edge("smiles_node", "loop_node")
|
| 66 |
+
builder.add_edge("name_node", "loop_node")
|
| 67 |
+
builder.add_edge("similars_node", "loop_node")
|
| 68 |
+
|
| 69 |
+
builder.add_conditional_edges("loop_node", get_chemtool, {
|
| 70 |
+
"smiles_tool": "smiles_node",
|
| 71 |
+
"name_tool": "name_node",
|
| 72 |
+
"similars_tool": "similars_node",
|
| 73 |
+
"loop_again": "first_node",
|
| 74 |
+
None: "parser_node"})
|
| 75 |
+
|
| 76 |
+
builder.add_conditional_edges("parser_node", loop_or_not, {
|
| 77 |
+
True: "retry_node",
|
| 78 |
+
False: "reflect_node"})
|
| 79 |
+
|
| 80 |
+
builder.add_edge("reflect_node", END)
|
| 81 |
+
|
| 82 |
+
graph = builder.compile()
|
| 83 |
+
|
| 84 |
+
def MoleculeAgent(smiles, name, task):
|
| 85 |
+
|
| 86 |
+
#if Similars_image.png exists, remove it
|
| 87 |
+
if os.path.exists('Similars_image.png'):
|
| 88 |
+
os.remove('Similars_image.png')
|
| 89 |
+
|
| 90 |
+
input = {
|
| 91 |
+
"messages": [
|
| 92 |
+
HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_name: {name}')
|
| 93 |
+
]
|
| 94 |
+
}
|
| 95 |
+
#print(input)
|
| 96 |
+
|
| 97 |
+
replies = []
|
| 98 |
+
for c in graph.stream(input): #, stream_mode='updates'):
|
| 99 |
+
m = re.findall(r'[a-z]+\_node', str(c))
|
| 100 |
+
if len(m) != 0:
|
| 101 |
+
reply = c[str(m[0])]['messages']
|
| 102 |
+
if 'assistant' in str(reply):
|
| 103 |
+
reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
|
| 104 |
+
replies.append(reply)
|
| 105 |
+
#check if image exists
|
| 106 |
+
if os.path.exists('Similars_image.png'):
|
| 107 |
+
img_loc = 'Similars_image.png'
|
| 108 |
+
img = Image.open(img_loc)
|
| 109 |
+
#else create a dummy blank image
|
| 110 |
+
else:
|
| 111 |
+
img = Image.new('RGB', (250, 250), color = (255, 255, 255))
|
| 112 |
+
|
| 113 |
+
return replies[-1], img
|
| 114 |
+
|
| 115 |
+
with gr.Blocks(fill_height=True) as forest:
|
| 116 |
+
gr.Markdown('''
|
| 117 |
+
# Molecule Agent - fetches names and SMILES from PubChem or finds
|
| 118 |
+
similar molecules.
|
| 119 |
+
''')
|
| 120 |
+
|
| 121 |
+
name, smiles = None, None
|
| 122 |
+
with gr.Row():
|
| 123 |
+
with gr.Column():
|
| 124 |
+
smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
|
| 125 |
+
name = gr.Textbox(label="Molecule Name of interest (optional): ", placeholder='none')
|
| 126 |
+
task = gr.Textbox(label="Task for Agent: ")
|
| 127 |
+
calc_btn = gr.Button(value = "Submit to Agent")
|
| 128 |
+
with gr.Column():
|
| 129 |
+
props = gr.Textbox(label="Agent results: ", lines=20 )
|
| 130 |
+
pic = gr.Image(label="Molecule")
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
calc_btn.click(MoleculeAgent, inputs = [smiles, name, task], outputs = [props, pic])
|
| 134 |
+
|
| 135 |
+
forest.launch(debug=True, share=True)
|