DCMNet / app.py
EricBoi's picture
Refactor app.py to improve user experience by updating SMILES input display, enhancing error handling for invalid inputs, and refining the layout for displaying DCM layers. Adjusted slider label for clarity and ensured consistent rendering dimensions for 3D models.
f3e14a9
import streamlit as st
from streamlit_ketcher import st_ketcher
from stmol import showmol
import py3Dmol
import io
import ase.io as ase_io
import numpy as np
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import streamlit.components.v1 as components
import streamlit as st
from streamlit_ketcher import st_ketcher
# Width (pixels) passed to every showmol() viewer call in this app.
WIDTH = 500

st.title("DCM-Net")

# Interactive molecule editor, pre-loaded with a proline SMILES; it hands back
# the SMILES string of whatever structure is currently drawn.
smiles = st_ketcher("OC(=O)[C@@H]1CCCN1")

# Echo the current SMILES back to the user.
st.write("SMILES input: {}".format(smiles))
# color charges by magnitude
def color_charges(mono, nDCM, atoms):
    """Map distributed-charge values to RGBA colours on a blue-white-red scale.

    Values are normalised against the fixed range [-1, 1] and passed through
    Matplotlib's ``"bwr"`` colormap (low -> blue, 0 -> white, high -> red).

    Parameters
    ----------
    mono : array-like
        Charge values; flattened before use. Presumably one entry per
        distributed charge, ordered atom by atom -- TODO confirm against
        ``run_dcm``.
    nDCM : int
        Number of distributed charges per atom.
    atoms : ase.Atoms
        The molecule; only its atom count (``len(atoms.numbers)``) is used.

    Returns
    -------
    numpy.ndarray
        One RGBA row per value for the first ``len(atoms.numbers) * nDCM``
        flattened entries of ``mono``.
    """
    import matplotlib
    from matplotlib import cm
    from matplotlib.colors import Normalize

    # `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9, so use
    # the colormap registry and fall back only on releases that predate it.
    try:
        cmap = matplotlib.colormaps["bwr"]
    except AttributeError:  # Matplotlib < 3.5 has no `colormaps` registry
        cmap = cm.get_cmap("bwr")
    mappable = cm.ScalarMappable(norm=Normalize(vmin=-1, vmax=1), cmap=cmap)

    # Keep only the values belonging to the molecule's charges.
    n_values = len(atoms.numbers) * nDCM
    return mappable.to_rgba(np.asarray(mono).flatten()[:n_values])
from dcm_app import run_dcm
def _rgba_to_hex(rgba):
    """Convert an RGB(A) colour with float channels in [0, 1] to ``#rrggbb``."""
    channels = (min(max(float(c), 0.0), 1.0) for c in rgba[:3])
    return "#" + "".join(f"{round(c * 255):02x}" for c in channels)


# Create a 3D view
def render_3d(atoms, dcm_charges=None, dcmcolors=None):
    """Build a py3Dmol view of a molecule and/or its distributed charges.

    Parameters
    ----------
    atoms : ase.Atoms or None
        Molecule drawn as thin sticks plus small spheres; skipped when None.
    dcm_charges : ase.Atoms or None
        Charge positions, written to XYZ and added as a separate (last) model
        drawn with even thinner sticks and smaller spheres; skipped when None.
    dcmcolors : sequence of RGB(A) colours or None
        Optional per-charge colours with float channels in [0, 1] (e.g. the
        output of ``color_charges``). Applied only when there is exactly one
        colour per entry of ``dcm_charges``; otherwise the viewer's default
        element colouring is kept.

    Returns
    -------
    py3Dmol.view
        The configured viewer, zoomed to fit its models.
    """
    view = py3Dmol.view()
    # Custom JS colour function injected into the viewer script.
    # NOTE(review): no style below references "customColorize", so this is
    # currently unused; the unquoting replace() near the end only has an
    # effect if a style uses {'colorfunc': 'customColorize'}.
    view.startjs += '''\n
let customColorize = function(atom){
// attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
if (atom.elem === 'X'){
return "#FF0000" // red
}
else if (atom.elem === 'He'){
return "#0000FF" // blue
}
else{
return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
}
}
\n'''
    if atoms is not None:
        # Write structure to xyz file.
        xyz = io.StringIO()
        ase_io.write(xyz, atoms, format='xyz')
        view.addModel(xyz.getvalue(), 'xyz')
        view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})
    if dcm_charges is not None:
        xyz_dcm = io.StringIO()
        ase_io.write(xyz_dcm, dcm_charges, format='xyz')
        view.addModel(xyz_dcm.getvalue(), 'xyz')
        view.setStyle({'model': -1}, {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}})
        # Previously `dcmcolors` was accepted but ignored. Colour each charge
        # individually, but only when the colours line up one-to-one with the
        # charges; a mismatch keeps the default colouring instead of mis-painting.
        if dcmcolors is not None and len(dcmcolors) == len(dcm_charges):
            for index, rgba in enumerate(dcmcolors):
                color = _rgba_to_hex(rgba)
                view.setStyle(
                    {'model': -1, 'index': index},
                    {
                        'stick': {'radius': 0.01, 'color': color},
                        'sphere': {'scale': 0.05, 'color': color},
                    },
                )
    # Turn a quoted "customColorize" (as JSON-serialised in a style) into a
    # reference to the JS function defined above.
    view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
    view.zoomTo()
    return view
# create a slider for n_dcm
n_dcm = st.slider("Number of Distributed Charges per Atom", min_value=1, max_value=4, value=4)

# create a button to run the DCM-Net
if st.button("Run DCM-Net"):
    # Bail early when the editor returned nothing usable (None, empty or
    # whitespace-only SMILES).
    if smiles is None or not smiles.strip():
        st.error("Please enter a valid SMILES string")
        st.stop()

    with st.spinner("Running DCM-Net..."):
        # Top-level UI boundary: an unparsable SMILES or model failure is shown
        # to the user instead of crashing the whole page with a traceback.
        try:
            results = run_dcm(smiles, n_dcm)
        except Exception as err:
            st.error(f"DCM-Net failed for SMILES {smiles!r}: {err}")
            st.stop()

    ase_atoms = results["atoms"]

    # Molecule in 3D next to its 2D depiction.
    col1, col2 = st.columns(2)
    with col1:
        showmol(render_3d(ase_atoms), height=300, width=WIDTH)
    with col2:
        st.image(results["smiles_image"])

    st.subheader("Model Output")
    # One section per DCM order up to the selected count. The DCM-1 charge
    # model is stored under "dcmol"; higher orders use "dcmol2", "dcmol3", ...
    for k in range(1, n_dcm + 1):
        st.markdown(f"### DCM-{k}")
        dcmol = results["dcmol" if k == 1 else f"dcmol{k}"]
        dcmcolors = color_charges(results[f"mono_dc{k}"], k, ase_atoms)
        # Left: molecule with its charges; right: the charges alone.
        col1, col2 = st.columns(2)
        with col1:
            showmol(render_3d(ase_atoms, dcmol), height=300, width=WIDTH)
        with col2:
            showmol(render_3d(None, dcmol, dcmcolors), height=300, width=WIDTH)