File size: 4,793 Bytes
3dfe804 d8bf97a 52c9ab4 245457d d8bf97a f3e14a9 d8bf97a f3e14a9 d8bf97a 88cd151 d8bf97a f3e14a9 d8bf97a f3e14a9 d8bf97a f3e14a9 d8bf97a f3e14a9 d8bf97a f3e14a9 d8bf97a f3e14a9 949746a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | import streamlit as st
from streamlit_ketcher import st_ketcher
from stmol import showmol
import py3Dmol
import io
import ase.io as ase_io
import numpy as np
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import streamlit.components.v1 as components
import streamlit as st
from streamlit_ketcher import st_ketcher
# Width (px) passed to every showmol() viewer in this app.
WIDTH = 500

st.title("DCM-Net")

# Molecule editor, pre-loaded with proline; st_ketcher hands back the SMILES
# string of whatever molecule is currently drawn in it.
DEFAULT_SMILES = "OC(=O)[C@@H]1CCCN1"
smiles = st_ketcher(DEFAULT_SMILES)

# Echo the SMILES of the drawn molecule back to the user.
st.write("SMILES input: {}".format(smiles))
# color charges by magnitude
def color_charges(mono, nDCM, atoms):
    """Map distributed-charge values to RGBA colours on a blue-white-red scale.

    Values are normalised against the fixed range [-1, 1]; anything outside
    that range takes the end colour of the colormap. Negative charges are
    therefore blue, zero is white, and positive charges are red.

    Args:
        mono: Array-like of charge values. It is flattened, and only the first
            ``len(atoms.numbers) * nDCM`` entries are used. Presumably one
            value per distributed charge, atom-major — TODO confirm against
            ``run_dcm``.
        nDCM: Number of distributed charges per atom.
        atoms: Object exposing ``numbers`` (e.g. ``ase.Atoms``). Only its
            length is used.

    Returns:
        ``numpy.ndarray`` of shape ``(k, 4)`` holding RGBA floats in [0, 1],
        where ``k = min(number of values in mono, len(atoms.numbers) * nDCM)``.
    """
    from matplotlib.colors import Normalize

    try:
        # ``matplotlib.cm.get_cmap`` was deprecated in 3.7 and removed in 3.9;
        # the ``colormaps`` registry (3.5+) is the supported lookup.
        from matplotlib import colormaps
        cmap = colormaps["bwr"]
    except ImportError:  # matplotlib < 3.5
        from matplotlib import cm
        cmap = cm.get_cmap("bwr")

    norm = Normalize(vmin=-1, vmax=1)
    n_charges = len(atoms.numbers) * nDCM
    values = np.asarray(mono).ravel()[:n_charges]
    # Same as ScalarMappable(norm, cmap).to_rgba(values) for 1-D input.
    return cmap(norm(values))
from dcm_app import run_dcm
# Create a 3D view
def render_3d(atoms, dcm_charges=None, dcmcolors=None):
    """Build a py3Dmol view of a molecule and/or its distributed charges.

    Args:
        atoms: Structure writable by ``ase.io.write`` (e.g. ``ase.Atoms``).
            It is drawn as sticks plus small spheres. Pass ``None`` to omit it.
        dcm_charges: Structure writable by ``ase.io.write`` holding the charge
            sites. It is added as a separate, thinner model. Pass ``None`` to
            omit it.
        dcmcolors: Optional sequence of RGB(A) colours with components in
            [0, 1], as returned by ``color_charges``.
            - Entry ``i`` colours the ``i``-th site of ``dcm_charges``.
            - Sites without an entry keep the default element colouring.
            - Ignored when ``dcm_charges`` is ``None``.

    Returns:
        The configured ``py3Dmol.view``.
    """
    view = py3Dmol.view()
    # NOTE(review): ``customColorize`` is injected into the page, but no style
    # below references it, so it currently has no visible effect.
    view.startjs += '''\n
let customColorize = function(atom){
// attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
if (atom.elem === 'X'){
return "#FF0000" // red
}
else if (atom.elem === 'He'){
return "#0000FF" // blue
}
else{
return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
}
}
\n'''
    if atoms is not None:
        # Write structure to xyz text and load it as a model.
        xyz = io.StringIO()
        ase_io.write(xyz, atoms, format='xyz')
        view.addModel(xyz.getvalue(), 'xyz')
        view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})
    if dcm_charges is not None:
        xyz_dcm = io.StringIO()
        ase_io.write(xyz_dcm, dcm_charges, format='xyz')
        view.addModel(xyz_dcm.getvalue(), 'xyz')
        # 'model': -1 selects the model just added (the charge sites).
        view.setStyle({'model': -1}, {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}})
        if dcmcolors is not None:
            # Colour each charge site by its value. Previously this argument
            # was accepted but silently ignored.
            for index, rgba in enumerate(dcmcolors):
                red, green, blue = (
                    min(255, max(0, int(round(float(c) * 255)))) for c in rgba[:3]
                )
                hex_color = f"#{red:02x}{green:02x}{blue:02x}"
                view.setStyle(
                    {'model': -1, 'index': index},
                    {
                        'stick': {'radius': 0.01, 'color': hex_color},
                        'sphere': {'scale': 0.05, 'color': hex_color},
                    },
                )
    view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
    view.zoomTo()
    return view
# Height (px) of every embedded 3D viewer.
VIEW_HEIGHT = 300

# create a slider for n_dcm
n_dcm = st.slider("Number of Distributed Charges per Atom", min_value=1, max_value=4, value=4)

# create a button to run the DCM-Net
if st.button("Run DCM-Net"):
    # Bail if there is no SMILES (the editor may yield None or "").
    if not smiles:
        st.error("Please enter a valid SMILES string")
        st.stop()
    # Show a spinner while the model runs.
    with st.spinner("Running DCM-Net..."):
        results = run_dcm(smiles, n_dcm)
    ase_atoms = results["atoms"]

    # Input molecule: 3D structure next to its 2D depiction.
    col1, col2 = st.columns(2)
    with col1:
        showmol(render_3d(ase_atoms), height=VIEW_HEIGHT, width=WIDTH)
    with col2:
        st.image(results["smiles_image"])

    st.subheader("Model Output")
    # One section per model size 1..n_dcm (the slider guarantees n_dcm >= 1).
    for k in range(1, n_dcm + 1):
        st.markdown(f"### DCM-{k}")
        dcmcolors = color_charges(results[f"mono_dc{k}"], k, ase_atoms)
        # The 1-charge result is stored under "dcmol"; the others under "dcmol{k}".
        dcmol = results["dcmol" if k == 1 else f"dcmol{k}"]
        # Left: molecule plus charges; right: charges only, coloured by value.
        col1, col2 = st.columns(2)
        with col1:
            showmol(render_3d(ase_atoms, dcmol), height=VIEW_HEIGHT, width=WIDTH)
        with col2:
            showmol(render_3d(None, dcmol, dcmcolors), height=VIEW_HEIGHT, width=WIDTH)
|