"""Streamlit front end for DCM-Net.

The page lets the user draw a molecule in a Ketcher editor, runs
``run_dcm`` on the resulting SMILES string and shows the returned geometry
together with the distributed-charge models in py3Dmol viewers.
"""

import io

import ase.io as ase_io
import numpy as np
import py3Dmol
import streamlit as st
import streamlit.components.v1 as components
from stmol import showmol
from streamlit_ketcher import st_ketcher

try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO

# Pixel size of every embedded 3D viewer.
WIDTH = 500
HEIGHT = 300

# JavaScript appended to the viewer start-up code. It defines a colour
# callback keyed on the element symbol of each atom.
# NOTE(review): no style spec in this file references `customColorize`, so the
# callback is defined but currently unused — confirm whether a `colorfunc`
# entry was meant to be added to the charge style.
_CUSTOM_COLORIZE_JS = '''\n
    let customColorize = function(atom){
        // attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
        if (atom.elem === 'X'){
            return "#FF0000" // red
        } else if (atom.elem === 'He'){
            return "#0000FF" // blue
        } else{
            return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
        }
    }
    \n'''


def color_charges(mono, nDCM, atoms):
    """Map charge values to RGBA colours on a blue-white-red scale.

    Args:
        mono: Array-like exposing ``flatten()``; only the first
            ``len(atoms.numbers) * nDCM`` flattened values are coloured.
        nDCM: Number of distributed charges per atom.
        atoms: Object exposing ``numbers`` (one entry per atom); presumably an
            ``ase.Atoms`` instance — verify against ``run_dcm``.

    Returns:
        The RGBA array produced by ``ScalarMappable.to_rgba`` for the selected
        values, normalised over the fixed range [-1, 1].
    """
    # Imported lazily, as in the original, so matplotlib is only loaded once a
    # colouring is actually requested.
    from matplotlib import cm
    from matplotlib.colors import Normalize

    # Passing the colormap by name avoids `cm.get_cmap`, which was removed in
    # Matplotlib 3.9; ScalarMappable resolves the name itself.
    mappable = cm.ScalarMappable(norm=Normalize(vmin=-1, vmax=1), cmap="bwr")
    n_charges = len(atoms.numbers) * nDCM
    return mappable.to_rgba(mono.flatten()[:n_charges])


def render_3d(atoms, dcm_charges=None, dcmcolors=None):
    """Build a py3Dmol view of a structure and/or its distributed charges.

    Args:
        atoms: Structure written with ``ase.io.write(..., format='xyz')``, or
            ``None`` to omit the molecule.
        dcm_charges: Optional second structure, written the same way and drawn
            with much smaller sticks and spheres.
        dcmcolors: Accepted for interface compatibility.
            NOTE(review): this value is not applied to the view anywhere in
            this function — confirm whether per-charge colouring is still
            planned.

    Returns:
        The configured ``py3Dmol.view``, zoomed to fit its contents.
    """
    view = py3Dmol.view()
    view.startjs += _CUSTOM_COLORIZE_JS

    if atoms is not None:
        # Serialise to XYZ in memory; py3Dmol takes the model as text.
        xyz = io.StringIO()
        ase_io.write(xyz, atoms, format='xyz')
        view.addModel(xyz.getvalue(), 'xyz')
        view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})

    if dcm_charges is not None:
        xyz_dcm = io.StringIO()
        ase_io.write(xyz_dcm, dcm_charges, format='xyz')
        view.addModel(xyz_dcm.getvalue(), 'xyz')
        # Model -1 is the most recently added one, i.e. the charges.
        view.setStyle(
            {'model': -1},
            {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}},
        )

    # Un-quote the callback name so a style that passes it as a string would
    # receive the JS function rather than a string literal.
    view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
    view.zoomTo()
    return view


def _dcmol_key(level):
    """Return the ``results`` key holding the charge structure for *level*."""
    # The first model is stored without a numeric suffix.
    return "dcmol" if level == 1 else f"dcmol{level}"


st.title("DCM-Net")

# Create a molecule editor, pre-loaded with a default structure.
smiles = "OC(=O)[C@@H]1CCCN1"
smiles = st_ketcher(smiles)

# Output the SMILES string of the drawn molecule.
st.write(f"SMILES input: {smiles}")

# Kept at its original position, after the editor has been drawn.
# NOTE(review): presumably deferred because importing dcm_app is slow — confirm
# before moving it to the top of the file.
from dcm_app import run_dcm  # noqa: E402

# Slider selecting how many distributed charges per atom to request.
n_dcm = st.slider(
    "Number of Distributed Charges per Atom", min_value=1, max_value=4, value=4
)

if st.button("Run DCM-Net"):
    # Bail out if there is no SMILES string to run on.
    if not smiles:
        st.error("Please enter a valid SMILES string")
        st.stop()

    with st.spinner("Running DCM-Net..."):
        results = run_dcm(smiles, n_dcm)
    print(results)

    ase_atoms = results["atoms"]

    # Geometry on the left, 2D depiction on the right.
    col1, col2 = st.columns(2)
    with col1:
        showmol(render_3d(ase_atoms), height=HEIGHT, width=WIDTH)
    with col2:
        st.image(results["smiles_image"])

    st.subheader("Model Output")

    # One section per charge model, from DCM-1 up to the requested level.
    for level in range(1, n_dcm + 1):
        st.markdown(f"### DCM-{level}")
        dcmcolors = color_charges(results[f"mono_dc{level}"], level, ase_atoms)
        dcmol = results[_dcmol_key(level)]

        # Two columns: molecule plus charges, and the charges on their own.
        col1, col2 = st.columns(2)
        with col1:
            showmol(render_3d(ase_atoms, dcmol), height=HEIGHT, width=WIDTH)
        with col2:
            showmol(
                render_3d(None, dcmol, dcmcolors), height=HEIGHT, width=WIDTH
            )