cafierom commited on
Commit
fe606d4
·
verified ·
1 Parent(s): acf312d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py CHANGED
@@ -27,6 +27,9 @@ import pubchempy as pcp
27
  import gradio as gr
28
  from PIL import Image
29
 
 
 
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
  hf = HuggingFacePipeline.from_model_id(
@@ -36,3 +39,97 @@ hf = HuggingFacePipeline.from_model_id(
36
 
37
  chat_model = ChatHuggingFace(llm=hf)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  import gradio as gr
28
  from PIL import Image
29
 
30
+ from chem_nodes import *
31
+ from agent_nodes import *
32
+
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
  hf = HuggingFacePipeline.from_model_id(
 
39
 
40
  chat_model = ChatHuggingFace(llm=hf)
41
 
42
+ builder = StateGraph(State)
43
+ builder.add_node("first_node", first_node)
44
+ builder.add_node("retry_node", retry_node)
45
+ builder.add_node("smiles_node", smiles_node)
46
+ builder.add_node("name_node", name_node)
47
+ builder.add_node("similars_node", similars_node)
48
+ builder.add_node("loop_node", loop_node)
49
+ builder.add_node("parser_node", parser_node)
50
+ builder.add_node("reflect_node", reflect_node)
51
+
52
+ builder.add_edge(START, "first_node")
53
+ builder.add_conditional_edges("first_node", get_chemtool, {
54
+ "smiles_tool": "smiles_node",
55
+ "name_tool": "name_node",
56
+ "similars_tool": "similars_node",
57
+ None: "parser_node"})
58
+
59
+ builder.add_conditional_edges("retry_node", get_chemtool, {
60
+ "smiles_tool": "smiles_node",
61
+ "name_tool": "name_node",
62
+ "similars_tool": "similars_node",
63
+ None: "parser_node"})
64
+
65
+ builder.add_edge("smiles_node", "loop_node")
66
+ builder.add_edge("name_node", "loop_node")
67
+ builder.add_edge("similars_node", "loop_node")
68
+
69
+ builder.add_conditional_edges("loop_node", get_chemtool, {
70
+ "smiles_tool": "smiles_node",
71
+ "name_tool": "name_node",
72
+ "similars_tool": "similars_node",
73
+ "loop_again": "first_node",
74
+ None: "parser_node"})
75
+
76
+ builder.add_conditional_edges("parser_node", loop_or_not, {
77
+ True: "retry_node",
78
+ False: "reflect_node"})
79
+
80
+ builder.add_edge("reflect_node", END)
81
+
82
+ graph = builder.compile()
83
+
84
+ def MoleculeAgent(smiles, name, task):
85
+
86
+ #if Similars_image.png exists, remove it
87
+ if os.path.exists('Similars_image.png'):
88
+ os.remove('Similars_image.png')
89
+
90
+ input = {
91
+ "messages": [
92
+ HumanMessage(f'query_smiles: {smiles}, query_task: {task}, query_name: {name}')
93
+ ]
94
+ }
95
+ #print(input)
96
+
97
+ replies = []
98
+ for c in graph.stream(input): #, stream_mode='updates'):
99
+ m = re.findall(r'[a-z]+\_node', str(c))
100
+ if len(m) != 0:
101
+ reply = c[str(m[0])]['messages']
102
+ if 'assistant' in str(reply):
103
+ reply = str(reply).split("<|assistant|>")[-1].split('#')[0].strip()
104
+ replies.append(reply)
105
+ #check if image exists
106
+ if os.path.exists('Similars_image.png'):
107
+ img_loc = 'Similars_image.png'
108
+ img = Image.open(img_loc)
109
+ #else create a dummy blank image
110
+ else:
111
+ img = Image.new('RGB', (250, 250), color = (255, 255, 255))
112
+
113
+ return replies[-1], img
114
+
115
+ with gr.Blocks(fill_height=True) as forest:
116
+ gr.Markdown('''
117
+ # Molecule Agent - fetches names and SMILES from PubChem or finds
118
+ similar molecules.
119
+ ''')
120
+
121
+ name, smiles = None, None
122
+ with gr.Row():
123
+ with gr.Column():
124
+ smiles = gr.Textbox(label="Molecule SMILES of interest (optional): ", placeholder='none')
125
+ name = gr.Textbox(label="Molecule Name of interest (optional): ", placeholder='none')
126
+ task = gr.Textbox(label="Task for Agent: ")
127
+ calc_btn = gr.Button(value = "Submit to Agent")
128
+ with gr.Column():
129
+ props = gr.Textbox(label="Agent results: ", lines=20 )
130
+ pic = gr.Image(label="Molecule")
131
+
132
+
133
+ calc_btn.click(MoleculeAgent, inputs = [smiles, name, task], outputs = [props, pic])
134
+
135
+ forest.launch(debug=True, share=True)