File size: 4,793 Bytes
3dfe804
d8bf97a
 
 
 
 
 
52c9ab4
 
 
 
 
 
245457d
d8bf97a
 
 
f3e14a9
d8bf97a
 
 
 
 
 
f3e14a9
d8bf97a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88cd151
d8bf97a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3e14a9
d8bf97a
 
 
f3e14a9
 
 
 
 
 
 
 
 
 
d8bf97a
 
f3e14a9
d8bf97a
f3e14a9
 
 
 
 
 
d8bf97a
 
f3e14a9
d8bf97a
f3e14a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
949746a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import streamlit as st
from streamlit_ketcher import st_ketcher
from stmol import showmol
import py3Dmol
import io
import ase.io as ase_io
import numpy as np
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

import streamlit.components.v1 as components


# NOTE: streamlit (`st`) and `st_ketcher` are already imported at the top of
# the file, so they are not re-imported here.

# Width handed to every `showmol` call that embeds a 3D viewer.
WIDTH = 500
st.title("DCM-Net")
# Molecule pre-loaded into the editor.
smiles = "OC(=O)[C@@H]1CCCN1"
# Create a molecule editor; its return value replaces the default SMILES.
smiles = st_ketcher(smiles)

# Output the SMILES string of the drawn molecule
st.write(f"SMILES input: {smiles}")


# color charges by magnitude
def color_charges(mono, nDCM, atoms):
    """Map distributed-charge values onto RGBA colours of the "bwr" colormap.

    Values are normalised linearly with ``vmin=-1`` and ``vmax=1`` before the
    colormap lookup.

    Args:
        mono: Array-like with a ``flatten()`` method holding the charge
            values. Only the first ``len(atoms.numbers) * nDCM`` flattened
            entries are coloured (presumably the remainder is padding --
            TODO confirm against ``run_dcm``).
        nDCM: Number of distributed charges per atom.
        atoms: Object exposing ``numbers`` (e.g. an ASE ``Atoms``); only the
            atom count is used.

    Returns:
        The result of ``ScalarMappable.to_rgba`` for the selected charges,
        i.e. one RGBA entry per coloured charge.
    """
    from matplotlib import cm
    from matplotlib.colors import Normalize

    # Pass the colormap by name and let ScalarMappable resolve it:
    # ``cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    mappable = cm.ScalarMappable(norm=Normalize(vmin=-1, vmax=1), cmap="bwr")

    n_charges = len(atoms.numbers) * nDCM
    return mappable.to_rgba(mono.flatten()[:n_charges])


from dcm_app import run_dcm
# Create a 3D view
def render_3d(atoms, dcm_charges=None, dcmcolors=None):
    """Build a py3Dmol view of a molecule and/or its distributed charges.

    Args:
        atoms: Structure that ``ase.io.write`` can serialise as xyz, or None
            to leave the molecule out. Drawn with sticks of radius 0.1 and
            spheres of scale 0.25.
        dcm_charges: Second structure (same requirement), or None. Added as
            its own model and drawn with sticks of radius 0.01 and spheres
            of scale 0.05.
        dcmcolors: Accepted for the callers' benefit but not used by this
            function at present.

    Returns:
        The ``py3Dmol.view`` after ``zoomTo()`` has been called on it.
    """

    def _as_xyz(structure):
        # Serialise in memory; addModel wants the file content as a string.
        buffer = io.StringIO()
        ase_io.write(buffer, structure, format='xyz')
        return buffer.getvalue()

    viewer = py3Dmol.view()

    # Inject a JS colouring callback into the generated page script.
    # No style set in this function refers to it.
    viewer.startjs += '''\n
    let customColorize = function(atom){
        // attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
        if (atom.elem === 'X'){
            return "#FF0000" // red
        }
        else if (atom.elem === 'He'){
            return "#0000FF" // blue
        }
        else{
            return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
        }
    }    
    \n'''

    molecule_style = {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}}
    charge_style = {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}}

    if atoms is not None:
        viewer.addModel(_as_xyz(atoms), 'xyz')
        # No selection given, so the style is applied to everything loaded so far.
        viewer.setStyle(molecule_style)

    if dcm_charges is not None:
        viewer.addModel(_as_xyz(dcm_charges), 'xyz')
        # Model -1 selects the model that was just added.
        viewer.setStyle({'model': -1}, charge_style)

    # Turn a quoted "customColorize" in the generated script into a bare
    # identifier so the browser would treat it as the function defined above.
    viewer.startjs = viewer.startjs.replace('"customColorize"', 'customColorize')
    viewer.zoomTo()
    return viewer

# Highest DCM model index offered by the slider and rendered below.
MAX_DCM = 4

# create a slider for n_dcm
n_dcm = st.slider("Number of Distributed Charges per Atom", min_value=1, max_value=MAX_DCM, value=MAX_DCM)

# create a button to run the DCM-Net
if st.button("Run DCM-Net"):
    # Bail out if the editor returned no SMILES (None or empty string).
    if not smiles:
        st.error("Please enter a valid SMILES string")
        st.stop()

    # Show a spinner while the model runs and the input molecule is rendered.
    with st.spinner("Running DCM-Net..."):
        results = run_dcm(smiles, n_dcm)
        print(results)
        ase_atoms = results["atoms"]
        # Left: 3D structure of the input molecule; right: its 2D depiction.
        col1, col2 = st.columns(2)
        with col1:
            showmol(render_3d(ase_atoms), height=300, width=WIDTH)
        with col2:
            st.image(results["smiles_image"])

    st.subheader("Model Output")
    # One section per DCM model, from DCM-1 up to the number chosen on the slider.
    for k in range(1, n_dcm + 1):
        st.markdown(f"### DCM-{k}")
        dcmcolors = color_charges(results[f"mono_dc{k}"], k, ase_atoms)
        # The first model's charges are stored under "dcmol"; later ones
        # carry the model index as a suffix ("dcmol2", "dcmol3", ...).
        dcmol = results["dcmol" if k == 1 else f"dcmol{k}"]
        # two columns showing combined and just the dcmol
        col1, col2 = st.columns(2)
        with col1:
            showmol(render_3d(ase_atoms, dcmol), height=300, width=WIDTH)
        with col2:
            showmol(render_3d(None, dcmol, dcmcolors), height=300, width=WIDTH)