| | import streamlit as st |
| | from streamlit_ketcher import st_ketcher |
| | from stmol import showmol |
| | import py3Dmol |
| | import io |
| | import ase.io as ase_io |
| | import numpy as np |
| | try: |
| | from StringIO import StringIO |
| | except ImportError: |
| | from io import StringIO |
| |
|
| | import streamlit.components.v1 as components |
| |
|
| |
|
# Pixel width of each py3Dmol viewer passed to ``showmol``.
WIDTH = 500
# Molecule preloaded into the Ketcher editor (L-proline).
DEFAULT_SMILES = "OC(=O)[C@@H]1CCCN1"

st.title("DCM-Net")

# Render the structure editor seeded with DEFAULT_SMILES; its return value is
# the SMILES string the rest of the page runs DCM-Net on.
smiles = st_ketcher(DEFAULT_SMILES)
st.write(f"SMILES input: {smiles}")
| |
|
| |
|
| | |
def color_charges(mono, nDCM, atoms):
    """Map distributed-charge values to RGBA colours on a blue-white-red scale.

    Values are normalised over the fixed range [-1, 1] and looked up in
    matplotlib's diverging ``"bwr"`` colormap: -1 maps to blue, 0 to
    (near-)white and +1 to red. Values outside the range take the colormap's
    under/over colour, which by default is the corresponding end colour.

    Args:
        mono: Array-like of charge values. It is flattened and only the first
            ``len(atoms.numbers) * nDCM`` entries are used (presumably
            ``nDCM`` values per atom -- TODO confirm against ``run_dcm``).
        nDCM: Number of distributed charges per atom.
        atoms: Object exposing ``numbers`` (e.g. ``ase.Atoms``); only its
            length is read.

    Returns:
        ``numpy.ndarray`` of shape ``(k, 4)`` with RGBA components in [0, 1],
        where ``k`` is ``len(atoms.numbers) * nDCM`` or the number of
        available values, whichever is smaller.
    """
    from matplotlib.colors import Normalize

    try:
        # Colormap registry, available since matplotlib 3.5.
        from matplotlib import colormaps
        cmap = colormaps["bwr"]
    except ImportError:
        # ``cm.get_cmap`` was deprecated in 3.7 and removed in 3.9; only use
        # it on releases that predate the registry.
        from matplotlib import cm
        cmap = cm.get_cmap("bwr")

    norm = Normalize(vmin=-1, vmax=1)
    n_sites = len(atoms.numbers) * nDCM
    values = np.asarray(mono).ravel()[:n_sites]
    # Equivalent to ScalarMappable(norm, cmap).to_rgba(values) for 1-D input.
    return cmap(norm(values))
| |
|
| |
|
| | from dcm_app import run_dcm |
| | |
def _rgba_to_hex(rgba):
    """Return ``#rrggbb`` for an RGB(A) colour with float components in [0, 1].

    Only the first three components are used (alpha is dropped); each is
    clamped to [0, 1] before being scaled to 0-255.
    """
    red, green, blue = (min(max(float(c), 0.0), 1.0) for c in rgba[:3])
    return f"#{round(red * 255):02x}{round(green * 255):02x}{round(blue * 255):02x}"


def render_3d(atoms, dcm_charges=None, dcmcolors=None):
    """Build a py3Dmol view of a molecule and/or its distributed charges.

    Args:
        atoms: Structure drawn with thin sticks and small spheres, or ``None``
            to omit it. Written with ``ase.io.write(..., format='xyz')``.
        dcm_charges: Optional structure holding the charge-site positions,
            also written as XYZ via ASE and added as a separate model drawn
            with very thin sticks and tiny spheres.
        dcmcolors: Optional sequence of RGB(A) colours with components in
            [0, 1] (e.g. the output of ``color_charges``). The i-th colour is
            applied to the sphere of the i-th site of ``dcm_charges``. It is
            ignored when ``dcm_charges`` is ``None``.

    Returns:
        The configured ``py3Dmol.view``, zoomed to fit its contents.
    """
    view = py3Dmol.view()

    # JS colour function injected into the page. It only takes effect when a
    # style references it as ``'colorfunc': 'customColorize'``; no style set
    # below currently does.
    view.startjs += '''\n
let customColorize = function(atom){
// attribute elem is from https://3dmol.csb.pitt.edu/doc/AtomSpec.html#elem
if (atom.elem === 'X'){
return "#FF0000" // red
}
else if (atom.elem === 'He'){
return "#0000FF" // blue
}
else{
return $3Dmol.getColorFromStyle(atom, {colorscheme: "whiteCarbon"});
}
}
\n'''

    if atoms is not None:
        xyz = io.StringIO()
        ase_io.write(xyz, atoms, format='xyz')
        view.addModel(xyz.getvalue(), 'xyz')
        view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.25}})

    if dcm_charges is not None:
        xyz_dcm = io.StringIO()
        ase_io.write(xyz_dcm, dcm_charges, format='xyz')
        view.addModel(xyz_dcm.getvalue(), 'xyz')
        view.setStyle({'model': -1}, {'stick': {'radius': 0.01}, 'sphere': {'scale': 0.05}})
        if dcmcolors is not None:
            # Colour each charge site of the model just added. Sites are
            # selected by 3Dmol's per-atom ``index`` AtomSpec property,
            # assumed to be the 0-based order of the sites in the XYZ block
            # (verify against the 3Dmol.js AtomSpec docs). Extra colours
            # match no atom; sites without a colour keep the style above.
            for site, rgba in enumerate(dcmcolors):
                view.setStyle(
                    {'model': -1, 'index': site},
                    {'stick': {'radius': 0.01},
                     'sphere': {'scale': 0.05, 'color': _rgba_to_hex(rgba)}},
                )

    # Style dicts are presumably serialised as JSON into ``startjs``, which
    # would quote a 'customColorize' colorfunc reference. Unquote it so the
    # browser sees the JS function defined above rather than a string.
    view.startjs = view.startjs.replace('"customColorize"', 'customColorize')
    view.zoomTo()
    return view
| |
|
| | |
# How many distributed charges per atom to show results for
# (selectable from 1 to 4, defaulting to the maximum).
n_dcm = st.slider(
    "Number of Distributed Charges per Atom",
    min_value=1,
    max_value=4,
    value=4,
)
| |
|
| | |
if st.button("Run DCM-Net"):
    # The editor may hand back None or an empty string when nothing is drawn.
    if not smiles:
        st.error("Please enter a valid SMILES string")
        st.stop()

    with st.spinner("Running DCM-Net..."):
        results = run_dcm(smiles, n_dcm)
    print(results)

    # Input molecule in 3D next to its 2D depiction.
    ase_atoms = results["atoms"]
    col1, col2 = st.columns(2)
    with col1:
        showmol(render_3d(ase_atoms), height=300, width=WIDTH)
    with col2:
        st.image(results["smiles_image"])

    st.subheader("Model Output")
    # One row per model k = 1..n_dcm. The values passed to color_charges live
    # under "mono_dc<k>"; the charge-site structures live under "dcmol<k>",
    # except for k == 1, which is stored under plain "dcmol".
    for k in range(1, n_dcm + 1):
        st.markdown(f"### DCM-{k}")
        site_colors = color_charges(results[f"mono_dc{k}"], k, ase_atoms)
        sites = results["dcmol" if k == 1 else f"dcmol{k}"]
        left, right = st.columns(2)
        with left:
            # Molecule with its charge sites overlaid.
            showmol(render_3d(ase_atoms, sites), height=300, width=WIDTH)
        with right:
            # Charge sites alone, with the colours from color_charges.
            showmol(render_3d(None, sites, site_colors), height=300, width=WIDTH)
| |
|
| |
|