diff --git "a/src/streamlit_app.py" "b/src/streamlit_app.py" deleted file mode 100644--- "a/src/streamlit_app.py" +++ /dev/null @@ -1,2220 +0,0 @@ -import streamlit as st -import os -import io -import tempfile -import torch -# FOR CPU only mode -# torch._dynamo.config.suppress_errors = True -# Or disable compilation entirely -# torch.backends.cudnn.enabled = False -import numpy as np -from ase import Atoms -from ase.io import read, write -from ase.optimize import BFGS, LBFGS, FIRE, LBFGSLineSearch, BFGSLineSearch, GPMin, MDMin -from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG -from ase.optimize.basin import BasinHopping -from ase.optimize.minimahopping import MinimaHopping -from ase.units import kB -from ase.constraints import FixAtoms -from ase.filters import FrechetCellFilter -from ase.visualize import view -import py3Dmol -from mace.calculators import mace_mp -from fairchem.core import pretrained_mlip, FAIRChemCalculator -from orb_models.forcefield import pretrained -from orb_models.forcefield.calculator import ORBCalculator -from sevenn.calculator import SevenNetCalculator -import pandas as pd -import yaml # Added for FairChem reference energies -import subprocess -import sys -import pkg_resources -from ase.vibrations import Vibrations -from mp_api.client import MPRester -import pubchempy as pcp -from io import StringIO -from pymatgen.symmetry.analyzer import SpacegroupAnalyzer -from pymatgen.io.ase import AseAtomsAdaptor -from pymatgen.core.structure import Molecule -import matplotlib.pyplot as plt -mattersim_available = True -if mattersim_available: - from mattersim.forcefield import MatterSimCalculator -# try: -# subprocess.check_call([sys.executable, "-m", "pip", "install", "mattersim"]) -# except Exception as e: -# print(f"Error during installation of mattersim: {e}") - -# try: -# from mattersim.forcefield import MatterSimCalculator -# mattersim_available = True -# print("\n\n\n\n\n\n\nSuccessfully imported MatterSimCalculator.\n\n\n\n\n\n\n\n\n\n") -# except ImportError as e: -# print(f"Failed to import MatterSimCalculator: {e} \n\n\n\n\n\n\n\n") -# mattersim_available = False -# # Define version threshold -# required_version = "2.0.0" - -# try: -# installed_version = pkg_resources.get_distribution("numpy").version -# if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version(required_version): -# print(f"numpy version {installed_version} >= {required_version}. Installing numpy<2.0.0...") -# subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy<2.0.0"]) -# else: -# print(f"numpy version {installed_version} is already < {required_version}. No action needed.") -# except pkg_resources.DistributionNotFound: -# print("numpy is not installed. Installing numpy<2.0.0...") -# subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy<2.0.0"]) - - -from huggingface_hub import login - -# try: -# hf_token = st.secrets["HF_TOKEN"]["token"] -# os.environ["HF_TOKEN"] = hf_token -# login(token=hf_token) -# except Exception as e: -# print("streamlit hf secret not defined/assigned") -try: - hf_token = os.getenv("YOUR SECRET KEY") # Replace with your actual Hugging Face token or manage secrets appropriately - if hf_token: - login(token = hf_token) - else: - print("Hugging Face token not found. Some models might not be accessible.") -except Exception as e: - print(f"hf login error: {e}") - - -os.environ["STREAMLIT_WATCHER_TYPE"] = "none" - -# YAML data for FairChem reference energies -ELEMENT_REF_ENERGIES_YAML = """ -oc20_elem_refs: -- 0.0 -- -0.16141512 -- 0.03262098 -- -0.04787699 -- -0.06299825 -- -0.14979306 -- -0.11657468 -- -0.10862579 -- -0.10298174 -- -0.03420248 -- 0.02673997 -- -0.03729558 -- 0.00515243 -- -0.07535697 -- -0.13663351 -- -0.12922852 -- -0.11796547 -- -0.07802946 -- -0.00672682 -- -0.04089589 -- -0.00024177 -- -1.74545186 -- -1.54220241 -- -1.0934019 -- -1.16168372 -- -1.23073475 -- -0.78852824 -- -0.71851599 -- -0.52465053 -- -0.02692092 -- -0.00317922 -- -0.06266862 -- -0.10835274 -- -0.12394474 -- -0.11351727 -- -0.07455817 -- -0.00258354 -- -0.04111325 -- -0.02090265 -- -1.89306078 -- -1.30591887 -- -0.63320009 -- -0.26230344 -- -0.2633669 -- -0.5160055 -- -0.95950798 -- -1.45589361 -- -0.0429969 -- -0.00026949 -- -0.05925609 -- -0.09734631 -- -0.12406852 -- -0.11427538 -- -0.07021442 -- 0.01091345 -- -0.05305289 -- -0.02427209 -- -0.19975668 -- -1.71692859 -- -1.53677781 -- -3.89987009 -- -10.70940462 -- -6.71693816 -- -0.28102249 -- -8.86944824 -- -7.95762687 -- -7.13041437 -- -6.64620014 -- -5.11482482 -- -4.42548227 -- 0.00848295 -- -0.06956227 -- -2.6748853 -- -2.21153293 -- -1.67367741 -- -1.07636151 -- -0.79009981 -- -0.16387243 -- -0.18164401 -- -0.04122529 -- -0.00041833 -- -0.05259382 -- -0.0934314 -- -0.11023834 -- -0.10039175 -- -0.06069209 -- 0.01790437 -- -0.04694024 -- 0.00334084 -- -0.06030621 -- -0.58793619 -- -1.27821808 -- -4.97483577 -- -5.66985655 -- -8.43154622 -- -11.15001317 -- -12.95770812 -- 0.0 -- -14.47602729 -- 0.0 -odac_elem_refs: -- 0.0 -- -1.11737936 -- -0.00011835 -- -0.2941727 -- -0.03868426 -- -0.34862832 -- -1.31552566 -- -3.12457285 -- -1.6052078 -- -0.49653389 -- -0.01137327 -- -0.21957281 -- -0.0008343 -- -0.2750172 -- -0.88417265 -- -1.887378 -- -0.94903558 -- -0.31628167 -- -0.02014536 -- -0.15901053 -- -0.00731884 -- -1.96521355 -- -1.89045209 -- -2.53057428 -- -5.43600675 -- -5.09739336 -- -3.03088746 -- -1.23786562 -- -0.40650749 -- -0.2416017 -- -0.01139188 -- -0.26282496 -- -0.82446455 -- -1.70237206 -- -0.84245376 -- -0.28544892 -- -0.02239991 -- -0.14115912 -- -0.02840799 -- -2.09540994 -- -1.85863996 -- -1.12257399 -- -4.32965355 -- -3.30670045 -- -1.19460755 -- -1.26257601 -- -1.46832888 -- -0.19779414 -- -0.0144274 -- -0.23668767 -- -0.70836953 -- -1.43186113 -- -0.71701186 -- -0.24883129 -- -0.01118184 -- -0.13173447 -- -0.0318395 -- -0.41195547 -- -1.23134873 -- -2.03082996 -- 0.1375954 -- -5.45866275 -- -7.59139905 -- -5.99965965 -- -8.43495767 -- -2.6578407 -- -7.77349787 -- -5.30762201 -- -5.15109657 -- -4.41466995 -- -0.02995219 -- -0.2544495 -- -3.23821202 -- -3.45887214 -- -4.53635003 -- -4.60979468 -- -2.90707964 -- -1.28286153 -- -0.57716664 -- -0.18337108 -- -0.01135944 -- -0.22045398 -- -0.66150479 -- -1.32506342 -- -0.66500178 -- -0.22643927 -- -0.00728197 -- -0.11208472 -- -0.00757856 -- -0.21798637 -- -0.91078787 -- -1.78187161 -- -3.89912261 -- -3.94192659 -- -7.59026042 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -omat_elem_refs: -- 0.0 -- -1.11700253 -- 0.00079886 -- -0.29731164 -- -0.04129868 -- -0.29106192 -- -1.27751531 -- -3.12342715 -- -1.54797136 -- -0.43969356 -- -0.01250908 -- -0.22855413 -- -0.00943179 -- -0.21707638 -- -0.82619133 -- -1.88667434 -- -0.89093583 -- -0.25816211 -- -0.02414768 -- -0.17662425 -- -0.02568319 -- -2.13001165 -- -2.38688845 -- -3.55934233 -- -5.44700879 -- -5.14749562 -- -3.30662847 -- -1.42167737 -- -0.63181379 -- -0.23449167 -- -0.01146636 -- -0.21291259 -- -0.77939897 -- -1.70148487 -- -0.78386705 -- -0.22690657 -- -0.02245409 -- -0.16092396 -- -0.02798717 -- -2.25685695 -- -2.23690495 -- -2.15347771 -- -4.60251809 -- -3.36416792 -- -2.23062607 -- -1.15550917 -- -1.47553527 -- -0.19918102 -- -0.01475888 -- -0.19767692 -- -0.68005773 -- -1.43073368 -- -0.65790462 -- -0.18915279 -- -0.01179476 -- -0.13507902 -- -0.03056979 -- -0.36017439 -- -0.86279246 -- -0.20573327 -- -0.2734463 -- -0.20046965 -- -0.25444338 -- -8.37972664 -- -9.58424928 -- -0.19466184 -- -0.24860115 -- -0.19531288 -- -0.15401392 -- -0.14577898 -- -0.19655747 -- -0.15645898 -- -3.49380556 -- -3.5317097 -- -4.57108006 -- -4.63425205 -- -2.88247063 -- -1.45679675 -- -0.50290184 -- -0.18521704 -- -0.01123956 -- -0.17483649 -- -0.63132037 -- -1.3248562 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- -0.24135757 -- -1.04601971 -- -2.04574044 -- -3.84544799 -- -7.28626119 -- -7.3136314 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -omol_elem_refs: -- 0.0 -- -13.44558 -- -78.82027 -- -203.32564 -- -398.94742 -- -670.75275 -- -1029.85403 -- -1485.54188 -- -2042.97832 -- -2714.24015 -- -3508.74317 -- -4415.24203 -- -5443.89712 -- -6594.61834 -- -7873.6878 -- -9285.6593 -- -10832.62132 -- -12520.66852 -- -14354.278 -- -16323.54671 -- -18436.47845 -- -20696.18244 -- -23110.5386 -- -25682.99429 -- -28418.37804 -- -31317.92317 -- -34383.42519 -- -37623.46835 -- -41039.92413 -- -44637.38634 -- -48417.14864 -- -52373.87849 -- -56512.76952 -- -60836.14871 -- -65344.28833 -- -70041.24251 -- -74929.56277 -- -653.64777 -- -833.31922 -- -1038.0281 -- -1273.96788 -- -1542.45481 -- -1850.74158 -- -2193.91654 -- -2577.18734 -- -3004.13604 -- -3477.52796 -- -3997.31825 -- -4563.75804 -- -5171.82293 -- -5828.85334 -- -6535.61529 -- -7291.54792 -- -8099.87914 -- -8962.17916 -- -546.03214 -- -690.6089 -- -854.11237 -- -12923.04096 -- -14064.26124 -- -15272.68689 -- -16550.20551 -- -17900.36515 -- -19323.23406 -- -20829.08848 -- -22428.73258 -- -24078.68008 -- -25794.42097 -- -27616.6819 -- -29523.5526 -- -31526.68012 -- -33615.37779 -- -1300.17791 -- -1544.40924 -- -1818.62298 -- -2123.14417 -- -2461.76028 -- -2833.76287 -- -3242.79895 -- -3690.363 -- -4174.99772 -- -4691.75674 -- -5245.36013 -- -5838.12005 -- -6469.07296 -- -7140.86455 -- -7854.60638 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -- 0.0 -omc_elem_refs: -- 0.0 -- -0.02831808 -- 4.512e-05 -- -0.03227157 -- -0.03842519 -- -0.05829283 -- -0.0845041 -- -0.08806738 -- -0.09021346 -- -0.06669846 -- -0.01218631 -- -0.03650269 -- -0.00059093 -- -0.05787736 -- -0.08730952 -- -0.0975534 -- -0.09264199 -- -0.07124762 -- -0.02374602 -- -0.05299112 -- -0.02631476 -- -1.7772147 -- -1.25083444 -- -0.79579447 -- -0.49099317 -- -0.31414986 -- -0.20292182 -- -0.14011632 -- -0.09929659 -- -0.03771207 -- -0.01117902 -- -0.06168715 -- -0.08873364 -- -0.09512942 -- -0.09035978 -- -0.06910849 -- -0.02244872 -- -0.05303651 -- -0.02871903 -- -1.94805417 -- -1.33379896 -- -0.69169331 -- -0.26184306 -- -0.20631599 -- -0.48251608 -- -0.96911893 -- -1.47569462 -- -0.03845194 -- -0.0142445 -- -0.07118991 -- -0.09940292 -- -0.09235056 -- -0.08755943 -- -0.06544925 -- -0.01246646 -- -0.04692937 -- -0.03225123 -- -0.26086039 -- -27.20024339 -- -0.08412926 -- -0.08225924 -- -0.07799715 -- -0.07806185 -- 0.00043759 -- -0.07459766 -- 0.0 -- -0.06842841 -- -0.07758266 -- -0.07025152 -- -0.08055003 -- -0.07118177 -- -0.07159568 -- -2.69202862 -- -2.21926765 -- -1.679756 -- -1.06135075 -- -0.4554231 -- -0.14488432 -- -0.18377098 -- -0.03603118 -- -0.01076585 -- -0.06381411 -- -0.0905623 -- -0.10095787 -- -0.09501217 -- -0.0574478 -- -0.00599173 -- -0.04134751 -- -0.0082683 -- -0.08704692 -- -0.49656425 -- -5.24233138 -- -2.32542606 -- -4.3376616 -- -5.96430676 -- 0.0 -- 0.0 -- -0.03842519 -- 0.0 -- 0.0 -""" -try: - ELEMENT_REF_ENERGIES = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML) -except yaml.YAMLError as e: - # st.error(f"Error parsing YAML reference energies: {e}") # st objects can only be used in main script flow - print(f"Error parsing YAML reference energies: {e}") - ELEMENT_REF_ENERGIES = {} # Fallback - -# Check if running on Streamlit Cloud vs locally -is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud' -MAX_ATOMS_CLOUD = 500 # Maximum atoms allowed on Streamlit Cloud -MAX_ATOMS_CLOUD_UMA = 500 - -# Set page configuration -st.set_page_config( - page_title="MLIP Playground - Run, Test and Benchmark MLIPs", - page_icon="🧪", - layout="wide" -) - -# Title and description -st.markdown('## MLIP Playground', unsafe_allow_html=True) -st.write('#### Run, test and compare 42 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials') -st.markdown('Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).', unsafe_allow_html=True) - -# Create a directory for sample structures if it doesn't exist -SAMPLE_DIR = "sample_structures" -os.makedirs(SAMPLE_DIR, exist_ok=True) - -# Dictionary of sample structures -SAMPLE_STRUCTURES = { - "Water": "H2O.xyz", - "Methane": "CH4.xyz", - "Benzene": "C6H6.xyz", - "Ethane": "C2H6.xyz", - "Caffeine": "caffeine.xyz", - "Ibuprofen": "ibuprofen.xyz", - "Silicon": "Si.cif", - "hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz", -} - -def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400, - show_path=True, path_color='red', path_radius=0.02): - """ - Visualize optimization trajectory with multiple frames - - Args: - trajectory: List of ASE atoms objects representing the optimization steps - style: Visualization style ('stick', 'ball', 'ball-stick') - show_unit_cell: Whether to show unit cell - show_path: Whether to show trajectory paths for each atom - path_color: Color of trajectory paths - path_radius: Radius of trajectory path cylinders - """ - if not trajectory: - return None - - view = py3Dmol.view(width=width, height=height) - - # Add all frames to the viewer - for frame_idx, atoms_obj in enumerate(trajectory): - xyz_str = "" - xyz_str += f"{len(atoms_obj)}\n" - xyz_str += f"Frame {frame_idx}\n" - for atom in atoms_obj: - xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n" - - view.addModel(xyz_str, "xyz") - - # Set style for all models - if style.lower() == 'ball-stick': - view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}}) - elif style.lower() == 'stick': - view.setStyle({'stick': {}}) - elif style.lower() == 'ball': - view.setStyle({'sphere': {'scale': 0.4}}) - else: - view.setStyle({'stick': {'radius': 0.15}}) - - # Add trajectory paths - if show_path and len(trajectory) > 1: - for atom_idx in range(len(trajectory[0])): - for frame_idx in range(len(trajectory) - 1): - start_pos = trajectory[frame_idx][atom_idx].position - end_pos = trajectory[frame_idx + 1][atom_idx].position - - view.addCylinder({ - 'start': {'x': start_pos[0], 'y': start_pos[1], 'z': start_pos[2]}, - 'end': {'x': end_pos[0], 'y': end_pos[1], 'z': end_pos[2]}, - 'radius': path_radius, - 'color': path_color, - 'alpha': 0.5 - }) - - # Add unit cell for the last frame - if show_unit_cell and trajectory[-1].pbc.any(): - cell = trajectory[-1].get_cell() - origin = np.array([0.0, 0.0, 0.0]) - if cell is not None and cell.any(): - edges = [ - (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]), - (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]), - (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]), - (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]), - (cell[0] + cell[1], cell[0] + cell[1] + cell[2]) - ] - for start, end in edges: - view.addCylinder({ - 'start': {'x': start[0], 'y': start[1], 'z': start[2]}, - 'end': {'x': end[0], 'y': end[1], 'z': end[2]}, - 'radius': 0.05, 'color': 'black', 'alpha': 0.7 - }) - - view.zoomTo() - view.setBackgroundColor('white') - return view - - -def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400): - """ - Create an animated trajectory visualization - """ - if not trajectory: - return None - - view = py3Dmol.view(width=width, height=height) - - # Add all frames - for frame_idx, atoms_obj in enumerate(trajectory): - xyz_str = "" - xyz_str += f"{len(atoms_obj)}\n" - xyz_str += f"Frame {frame_idx}\n" - for atom in atoms_obj: - xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n" - - view.addModel(xyz_str, "xyz") - - # Set style - if style.lower() == 'ball-stick': - view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}}) - elif style.lower() == 'stick': - view.setStyle({'stick': {}}) - elif style.lower() == 'ball': - view.setStyle({'sphere': {'scale': 0.4}}) - else: - view.setStyle({'stick': {'radius': 0.15}}) - - # Add unit cell for last frame - if show_unit_cell and trajectory[-1].pbc.any(): - cell = trajectory[-1].get_cell() - origin = np.array([0.0, 0.0, 0.0]) - if cell is not None and cell.any(): - edges = [ - (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]), - (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]), - (cell[0] + cell[1], cell[0] + cell[1] + cell[2]), - (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]), - (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]) - ] - for start, end in edges: - view.addCylinder({ - 'start': {'x': start[0], 'y': start[1], 'z': start[2]}, - 'end': {'x': end[0], 'y': end[1], 'z': end[2]}, - 'radius': 0.05, 'color': 'black', 'alpha': 0.7 - }) - - view.zoomTo() - view.setBackgroundColor('white') - - # Enable animation - view.animate({'loop': 'forward', 'reps': 0, 'interval': 500}) - - return view - - -# Streamlit implementation example -def display_optimization_trajectory(trajectory, viz_style='ball-stick'): - """ - Display optimization trajectory in Streamlit with controls - """ - if not trajectory: - st.error("No trajectory data available") - return - - st.subheader(f"Optimization Trajectory ({len(trajectory)} steps)") - - # Trajectory options - col1, col2 = st.columns(2) - - with col1: - viz_mode = st.selectbox( - "Visualization Mode", - ["Animation", "Static with paths", "Step-by-step"], - key="viz_mode" - ) - - with col2: - if viz_mode == "Static with paths": - show_paths = st.checkbox("Show trajectory paths", value=True) - path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0) - elif viz_mode == "Step-by-step": - frame_idx = st.slider("Frame", 0, len(trajectory)-1, 0, key="frame_slider") - - # Display visualization based on mode - if viz_mode == "Static with paths": - opt_view = get_trajectory_viz( - trajectory, - style=viz_style, - show_unit_cell=True, - width=400, - height=400, - show_path=show_paths, - path_color=path_color - ) - st.components.v1.html(opt_view._make_html(), width=400, height=400) - - elif viz_mode == "Animation": - opt_view = get_animated_trajectory_viz( - trajectory, - style=viz_style, - show_unit_cell=True, - width=400, - height=400 - ) - st.components.v1.html(opt_view._make_html(), width=400, height=400) - - elif viz_mode == "Step-by-step": - opt_view = get_structure_viz2( - trajectory[frame_idx], - style=viz_style, - show_unit_cell=True, - width=400, - height=400 - ) - st.components.v1.html(opt_view._make_html(), width=400, height=400) - st.write(f"Step {frame_idx + 1} of {len(trajectory)}") - -def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400): - xyz_str = "" - xyz_str += f"{len(atoms_obj)}\n" - xyz_str += "Structure\n" - for atom in atoms_obj: - xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n" - - view = py3Dmol.view(width=width, height=height) - view.addModel(xyz_str, "xyz") - - if style.lower() == 'ball-stick': - view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}}) - elif style.lower() == 'stick': - view.setStyle({'stick': {}}) - elif style.lower() == 'ball': - view.setStyle({'sphere': {'scale': 0.4}}) - else: - view.setStyle({'stick': {'radius': 0.15}}) - - if show_unit_cell and atoms_obj.pbc.any(): # Check pbc.any() - cell = atoms_obj.get_cell() - origin = np.array([0.0, 0.0, 0.0]) - if cell is not None and cell.any(): # Ensure cell is not None and not all zeros - edges = [ - (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]), - (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]), - (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]), - (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]), - (cell[0] + cell[1], cell[0] + cell[1] + cell[2]) - ] - for start, end in edges: - view.addCylinder({ - 'start': {'x': start[0], 'y': start[1], 'z': start[2]}, - 'end': {'x': end[0], 'y': end[1], 'z': end[2]}, - 'radius': 0.05, 'color': 'black', 'alpha': 0.7 - }) - view.zoomTo() - view.setBackgroundColor('white') - return view - -opt_log = [] # Define globally or pass around if necessary -table_placeholder = st.empty() # Define globally if updated from callback - -def streamlit_log(opt): - global opt_log, table_placeholder - try: - energy = opt.atoms.get_potential_energy() - forces = opt.atoms.get_forces() - fmax_step = np.max(np.linalg.norm(forces, axis=1)) if forces.shape[0] > 0 else 0.0 - opt_log.append({ - "Step": opt.nsteps, - "Energy (eV)": round(energy, 6), - "Fmax (eV/Å)": round(fmax_step, 6) - }) - df = pd.DataFrame(opt_log) - table_placeholder.dataframe(df) - except Exception as e: - st.warning(f"Error in optimization logger: {e}") - - -def check_atom_limit(atoms_obj, selected_model): - if atoms_obj is None: - return True - num_atoms = len(atoms_obj) - limit = MAX_ATOMS_CLOUD_UMA if ('UMA' in selected_model or 'ESEN MD' in selected_model) else MAX_ATOMS_CLOUD - if num_atoms > limit: - st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.") - return False - return True - -MACE_MODELS = { - "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model", - "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model", - "MACE OMAT Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-small.model", - "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model", - "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model", - "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model", - "MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model", - "MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model", - "MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model", - "MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model", - "MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", # Corrected name from original code - "MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model", - "MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model", - "MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model", - "MACE ANI-CC Large (500k)": "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model", - "MACE OMOL-0 XL 4M": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/mace-omol-0-extra-large-4M.model", - "MACE OMOL-0 XL 1024": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/MACE-omol-0-extra-large-1024.model", - "MACE OFF 23 Large": "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_large.model", - "MACE OFF 23 Medium": "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_medium.model", - "MACE OFF 23 Small": "https://github.com/ACEsuit/mace-off/raw/main/mace_off23/MACE-OFF23_small.model", - "MACE OFF 24 Medium": "https://github.com/ACEsuit/mace-off/raw/main/mace_off24/MACE-OFF24_medium.model" -} - -FAIRCHEM_MODELS = { - "UMA Small 1": "uma-s-1", - "UMA Small 1.1": "uma-s-1p1", - "ESEN MD Direct All OMOL": "esen-md-direct-all-omol", - "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol", - "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol" -} -# Define the available ORB models -ORB_MODELS = { - "V3 OMOL Conservative": pretrained.orb_v3_conservative_omol, - "V3 OMOL Direct": pretrained.orb_v3_direct_omol, - "V3 OMAT Conservative (inf)": pretrained.orb_v3_conservative_inf_omat, - "V3 OMAT Conservative (20)": pretrained.orb_v3_conservative_20_omat, - "V3 OMAT Direct (inf)": pretrained.orb_v3_direct_inf_omat, - "V3 OMAT Direct (20)": pretrained.orb_v3_direct_20_omat, - "V3 MPA Conservative (inf)": pretrained.orb_v3_conservative_inf_mpa, - "V3 MPA Conservative (20)": pretrained.orb_v3_conservative_20_mpa, - "V3 MPA Direct (inf)": pretrained.orb_v3_direct_inf_mpa, - "V3 MPA Direct (20)": pretrained.orb_v3_direct_20_mpa, -} -# Define the available MatterSim models -MATTERSIM_MODELS = { - "V1 SMALL": "MatterSim-v1.0.0-1M.pth", - "V1 LARGE": "MatterSim-v1.0.0-5M.pth" -} -SEVEN_NET_MODELS = { - "7net-0": "7net-0", - "7net-l3i5": "7net-l3i5", - "7net-omat": "7net-omat", - "7net-mf-ompa": "7net-mf-ompa" -} -@st.cache_resource -def get_mace_model(model_path, dispersion, device, selected_default_dtype): - return mace_mp(model=model_path, dispersion=dispersion, device=device, default_dtype=selected_default_dtype) - -@st.cache_resource -def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc): # Renamed args to avoid conflict - predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device) - if "UMA Small" in selected_model_name: - calc = FAIRChemCalculator(predictor, task_name=selected_task_type_fc) - else: - calc = FAIRChemCalculator(predictor, task_name="omol") - return calc - -# --- INITIALIZATION (Must be run first) --- -if "atoms" not in st.session_state: - st.session_state.atoms = None -if "atoms_list" not in st.session_state: - st.session_state.atoms_list = [] - -# Reset atoms state if input method changes, to prevent using old data -# Use a key to track the currently active input method -if 'current_input_method' not in st.session_state: - st.session_state.current_input_method = "Select Example" - -st.sidebar.markdown("## Input Options") -input_method = st.sidebar.radio("Choose Input Method:", - ["Select Example", "Upload File", "Paste Content", "Materials Project ID", "PubChem", "Batch Upload"]) - -# If the input method changes, clear the loaded structure -if input_method != st.session_state.current_input_method: - st.session_state.atoms = None - st.session_state.current_input_method = input_method - -# --- UPLOAD FILE --- -if input_method == "Upload File": - uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"]) - - # Load immediately upon file upload/change (no button needed) - if uploaded_file: - try: - # Check if this file content has already been loaded to prevent redundant temp file operations - if 'uploaded_file_hash' not in st.session_state or st.session_state.uploaded_file_hash != uploaded_file.name: - - # Use tempfile to handle the uploaded file content - with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file: - tmp_file.write(uploaded_file.getvalue()) - tmp_filepath = tmp_file.name - - atoms_to_store = read(tmp_filepath) - st.session_state.atoms = atoms_to_store - st.session_state.uploaded_file_hash = uploaded_file.name # Track the loaded file - st.sidebar.success(f"Successfully loaded structure with {len(atoms_to_store)} atoms!") - - except Exception as e: - st.sidebar.error(f"Error loading file: {str(e)}") - st.session_state.atoms = None - st.session_state.uploaded_file_hash = None # Clear hash on failure - finally: - # Clean up the temporary file - if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath): - os.unlink(tmp_filepath) - else: - # Clear structure if file uploader is empty - st.session_state.atoms = None - -# --- SELECT EXAMPLE --- -elif input_method == "Select Example": - # Load immediately upon selection change (no button needed) - example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys())) - - # Only load if a valid example is selected and it's different from the current state - if example_name and (st.session_state.atoms is None or st.session_state.atoms.info.get('source_name') != example_name): - file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name]) - try: - atoms_to_store = read(file_path) - atoms_to_store.info['source_name'] = example_name # Add a tag for tracking - st.session_state.atoms = atoms_to_store - st.sidebar.success(f"Loaded {example_name} with {len(atoms_to_store)} atoms!") - except Exception as e: - st.sidebar.error(f"Error loading example: {str(e)}") - st.session_state.atoms = None - -# --- PASTE CONTENT --- -elif input_method == "Paste Content": - file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"]) - content = st.sidebar.text_area("Paste file content here:", height=200, key="paste_content_input") - - # Load immediately upon content change (no button needed) - # Check if content is present and is different from the last successfully parsed content - if content: - # Simple check to avoid parsing on every single character change - if 'last_parsed_content' not in st.session_state or st.session_state.last_parsed_content != content: - - try: - suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"} - suffix = suffix_map.get(file_format, ".xyz") - - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file: - tmp_file.write(content.encode()) - tmp_filepath = tmp_file.name - - atoms_to_store = read(tmp_filepath) - st.session_state.atoms = atoms_to_store - st.session_state.last_parsed_content = content # Track the parsed content - st.sidebar.success(f"Successfully parsed structure with {len(atoms_to_store)} atoms!") - except Exception as e: - st.sidebar.error(f"Error parsing content: {str(e)}") - st.session_state.atoms = None - st.session_state.last_parsed_content = None - finally: - if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath): - os.unlink(tmp_filepath) - else: - # Clear structure if text area is empty - st.session_state.atoms = None - -# --- PUBCHEM SEARCH MODE --- -elif input_method == "PubChem": - - - st.sidebar.markdown("### Search PubChem") - - query = st.sidebar.text_input("Enter name or formula (e.g., H2O, water, methane):", - key="pubchem_query", value="water") - - # Reset atoms if no query - if query.strip() == "": - st.session_state.atoms = None - - # Step 1: Search PubChem - if query and query.strip(): - # Avoid re-searching if query is unchanged - if "pubchem_last_query" not in st.session_state or st.session_state.pubchem_last_query != query: - try: - with st.spinner("Searching PubChem..."): - results = pcp.get_compounds(query, "name") # name OR formula works - st.session_state.pubchem_results = results - st.session_state.pubchem_last_query = query - except Exception as e: - st.sidebar.error(f"Error searching PubChem: {str(e)}") - st.session_state.pubchem_results = None - - results = st.session_state.get("pubchem_results", []) - if results: - # Convert to displayable table - df = pd.DataFrame( - [(c.cid, c.iupac_name, c.molecular_formula, c.molecular_weight, c.isomeric_smiles) - for c in results], - columns=["CID", "Name", "Formula", "Weight", "SMILES"] - ) - st.sidebar.success(f"Found {len(df)} result(s).") - st.sidebar.dataframe(df) - - # Choose a CID - cid = st.sidebar.selectbox("Select CID", df["CID"], key="pubchem_cid") - - # Step 2: Retrieve 3D structure for selected CID - if cid: - if "pubchem_last_cid" not in st.session_state or st.session_state.pubchem_last_cid != cid: - try: - with st.spinner("Fetching 3D coordinates..."): - # Function to format floating-point numbers with alignment - def format_number(num, width=10, precision=5): - # Handles positive/negative numbers while maintaining alignment - return f"{num: {width}.{precision}f}" - # CID to XYZ - def generate_xyz_coordinates(cid): - compound = pcp.Compound.from_cid(cid, record_type='3d') - atoms = compound.atoms - coords = [(atom.x, atom.y, atom.z) for atom in atoms] - - num_atoms = len(atoms) - xyz_text = f"{num_atoms}\n{compound.cid}\n" - - for atom, coord in zip(atoms, coords): - atom_symbol = atom.element - x, y, z = coord - xyz_text += f"{atom_symbol} {format_number(x, precision=8)} {format_number(y, precision=8)} {format_number(z, precision=8)}\n" - - return xyz_text - def get_molecule(cid): - xyz_str = generate_xyz_coordinates(cid) - return Molecule.from_str(xyz_str, fmt='xyz'), xyz_str - # Fetch SDF with 3D conformer - # sdf_str = pcp.Compound.from_cid(int(cid)).to_sdf() - selected_molecule, xyz_str = get_molecule(cid) - - # Convert SDF → ASE Atoms using temporary memory buffer - atoms_to_store = read(StringIO(xyz_str), format="xyz") - - atoms_to_store.info["source_name"] = f"PubChem CID {cid}" - st.session_state.atoms = atoms_to_store - st.session_state.pubchem_last_cid = cid - - st.sidebar.success(f"Loaded PubChem structure with {len(atoms_to_store)} atoms!") - - except Exception as e: - st.sidebar.error(f"Unable to retrieve 3D structure: {str(e)}") - st.session_state.atoms = None - st.session_state.pubchem_last_cid = None - else: - st.sidebar.info("No PubChem results found.") - -# --- MATERIALS PROJECT ID --- -elif input_method == "Materials Project ID": - mp_api_key = os.getenv("MP_API_KEY") - material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149", key="mp_id_input") - cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'], key="cell_type_radio") - - # Reactive Loading (No button needed) - # Check for valid inputs and if the current material_id/cell_type is different from the loaded one - if mp_api_key and material_id: - - # Simple tracking to avoid API call if nothing has changed - current_mp_key = f"{material_id}_{cell_type}" - if 'last_fetched_mp_key' not in st.session_state or st.session_state.last_fetched_mp_key != current_mp_key: - - try: - with st.spinner(f"Fetching {material_id}..."): - with MPRester(mp_api_key) as mpr: - pmg_structure = mpr.get_structure_by_material_id(material_id) - analyzer = SpacegroupAnalyzer(pmg_structure) - - if cell_type == 'Conventional Unit Cell': - final_structure = analyzer.get_conventional_standard_structure() - else: - final_structure = analyzer.get_primitive_standard_structure() - - atoms_to_store = AseAtomsAdaptor.get_atoms(final_structure) - st.session_state.atoms = atoms_to_store - st.session_state.last_fetched_mp_key = current_mp_key # Update tracking key - st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.") - - except Exception as e: - st.sidebar.error(f"Error fetching data: {str(e)}") - st.session_state.atoms = None - st.session_state.last_fetched_mp_key = None # Clear key on failure - - # Handle error messages when inputs are missing - elif not mp_api_key: - st.sidebar.error("Please set your Materials Project API Key (MP_API_KEY environment variable).") - elif not material_id: - st.sidebar.error("Please enter a Material ID.") - -# --- BATCH UPLOAD MULTIPLE FILES --- -elif input_method == "Batch Upload": - - uploaded_files = st.sidebar.file_uploader( - "Upload multiple structure files", - type=["xyz", "cif", "POSCAR", "vasp", "CONTCAR", "mol", "sdf", "tmol", "extxyz"], - accept_multiple_files=True - ) - - # Clear state if no files present - if not uploaded_files: - st.session_state.atoms_list = [] - st.session_state.atoms = None - - else: - atoms_list = [] - errors = [] - - for file in uploaded_files: - try: - with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[1]) as tmp: - tmp.write(file.getvalue()) - tmp_path = tmp.name - - atoms_obj = read(tmp_path) - atoms_obj.info["source_name"] = file.name - atoms_list.append(atoms_obj) - - except Exception as e: - errors.append(f"{file.name}: {str(e)}") - - finally: - if "tmp_path" in locals() and os.path.exists(tmp_path): - os.unlink(tmp_path) - - # Store everything only if at least one success - if atoms_list: - st.session_state.atoms_list = atoms_list - st.session_state.atoms = atoms_list[0] # default: first item - st.sidebar.success(f"Loaded {len(atoms_list)} structures successfully!") - - if len(atoms_list) > 1: - st.sidebar.info("You can now process them as a batch.") - - if errors: - st.sidebar.error("Some files could not be loaded:\n" + "\n".join(errors)) -# ---------------------------------------------------- -# --- FINAL STRUCTURE RETRIEVAL (The persistent structure) --- -# ---------------------------------------------------- -# This is the single source of truth for the rest of your app -atoms = st.session_state.atoms - -if atoms is not None: - if not hasattr(atoms, 'info'): - atoms.info = {} - atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge - atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1) - - # Display confirmation in the main area (optional, helps the user confirm what's loaded) - # st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)") - -st.sidebar.markdown("## Model Selection") -if mattersim_available: - model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "SEVEN_NET", "MatterSim"]) -else: - model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "SEVEN_NET"]) - -selected_task_type = None # For FairChem UMA -if model_type == "MACE": - selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys())) - model_path = MACE_MODELS[selected_model] - if selected_model == "MACE OMAT Medium": - st.sidebar.warning("Using model under Academic Software License (ASL).") - # selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64']) - selected_default_dtype = 'float64' - dispersion = st.sidebar.checkbox("Dispersion correction?", value=False) - if selected_model == "MACE OMOL-0 XL 4M" or selected_model == "MACE OMOL-0 XL 1024": - - charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=atoms.info.get("charge",0)) - spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=1, value=int(atoms.info.get("spin",0) if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S - atoms.info["charge"] = charge - atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity - # else: - # atoms.info["charge"] = 0 - # atoms.info["spin"] = 1 # FairChem expects multiplicity -if model_type == "FairChem": - selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys())) - model_path = FAIRCHEM_MODELS[selected_model] - if "UMA Small" in selected_model: - st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.") - selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"]) - if selected_task_type == "omol" and atoms is not None: - charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=atoms.info.get("charge",0)) - spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=1, value=int(atoms.info.get("spin",0) if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S - atoms.info["charge"] = charge - atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity - else: - atoms.info["charge"] = 0 - atoms.info["spin"] = 1 # FairChem expects multiplicity -if model_type == "ORB": - selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys())) - model_path = ORB_MODELS[selected_model] - if "omat" in selected_model: - st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.") - # selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64']) - selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest']) -if model_type == "MatterSim": - selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys())) - model_path = MATTERSIM_MODELS[selected_model] -if model_type == "SEVEN_NET": - selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys())) - if selected_model == '7net-mf-ompa': - selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat24', 'mpa']) - model_path = SEVEN_NET_MODELS[selected_model] -if atoms is not None: - if not check_atom_limit(atoms, selected_model): - st.stop() # Stop execution if limit exceeded - -device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"], index=0 if not torch.cuda.is_available() else 1) -device = "cuda" if device == "CUDA (GPU)" and torch.cuda.is_available() else "cpu" - -if device == "cpu" and torch.cuda.is_available(): - st.sidebar.info("GPU is available but CPU was selected.") -elif device == "cpu" and not torch.cuda.is_available(): - st.sidebar.info("No GPU detected. Using CPU.") - -st.sidebar.markdown("## Task Selection") -task = st.sidebar.selectbox("Select Calculation Task:", - ["Energy Calculation", - "Energy + Forces Calculation", - "Atomization/Cohesive Energy", - "Geometry Optimization", - "Cell + Geometry Optimization", - #"Global Optimization", - "Vibrational Mode Analysis", - #"Phonons" - ]) - -if "Optimization" in task: - # st.sidebar.markdown("### Optimization Parameters") - # max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps - # fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax - # optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type - st.sidebar.markdown("### Optimization Parameters") - - # 1. Configuration for GLOBAL Optimization - if task == "Global Optimization": - global_method = st.sidebar.selectbox("Method:", ["Basin Hopping", "Minima Hopping"]) - - # Common parameters - temperature_K = st.sidebar.number_input("Temperature (K):", min_value=10.0, max_value=2000.0, value=300.0, step=10.0) - global_steps = st.sidebar.number_input("Search Steps:", min_value=10, max_value=500, value=50, step=10) - # Basin Hopping specific - if global_method == "Basin Hopping": - dr_amp = st.sidebar.number_input("Displacement Amplitude (Å):", min_value=0.1, max_value=2.0, value=0.7, step=0.1) - fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f") - - # Minima Hopping specific - elif global_method == "Minima Hopping": - st.sidebar.caption("Minima Hopping automates threshold adjustments to escape local minima.") - fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f") - - # 2. Configuration for LOCAL/CELL Optimization - else: - max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) - fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") - # optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) - optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "BFGSLineSearch", "LBFGS", "LBFGSLineSearch", "FIRE", "GPMin", "MDMin"], index=1) - -if "Vibration" in task: - st.write("### Thermodynamic Quantities (Molecule Only)") - T = st.sidebar.number_input("Temperature (K)", value=298.15) - -if atoms is not None: - col1, col2 = st.columns(2) - - with col1: - st.markdown('### Structure Visualization', unsafe_allow_html=True) - viz_style = st.selectbox("Select Visualization Style:", - ["ball-stick", - "stick", - "ball"]) - view_3d = get_structure_viz2(atoms, style=viz_style, show_unit_cell=True, width=400, height=400) - st.components.v1.html(view_3d._make_html(), width=400, height=400) - - st.markdown("### Structure Information") - atoms_info = { - "Number of Atoms": len(atoms), - "Chemical Formula": atoms.get_chemical_formula(), - "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(), - "Cell Dimensions": np.round(atoms.cell.cellpar(),3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic", - "Atom Types": ", ".join(sorted(list(set(atoms.get_chemical_symbols())))) - } - for key, value in atoms_info.items(): - st.write(f"**{key}:** {value}") - - with col2: - st.markdown('## Calculation Setup', unsafe_allow_html=True) - st.markdown("### Selected Model") - st.write(f"**Model Type:** {model_type}") - st.write(f"**Model:** {selected_model}") - if model_type == "FairChem" and "UMA Small" in selected_model: - st.write(f"**UMA Task Type:** {selected_task_type}") - if model_type == "MACE": - st.write(f"**Dispersion:** {dispersion}") - st.write(f"**Device:** {device}") - - st.markdown("### Selected Task") - st.write(f"**Task:** {task}") - - if "Geometry Optimization" in task: - st.write(f"**Max Steps:** {max_steps}") - st.write(f"**Convergence Threshold:** {fmax} eV/Å") - st.write(f"**Optimizer:** {optimizer_type}") - - run_calculation = st.button("Run Calculation", type="primary") - - if run_calculation: - # Delete all the items in Session state - for key in st.session_state.keys(): - del st.session_state[key] - results = {} - #global table_placeholder # Ensure they are accessible - opt_log = [] # Reset log for each run - if "Optimization" in task: - table_placeholder = st.empty() # Recreate placeholder for table - - try: - torch.set_default_dtype(torch.float32) - with st.spinner("Running calculation... Please wait."): - calc_atoms = atoms.copy() - - if model_type == "MACE": - # st.write("Setting up MACE calculator...") - calc = get_mace_model(model_path, dispersion, device, 'float32') - elif model_type == "FairChem": # FairChem - # st.write("Setting up FairChem calculator...") - # Workaround for potential dtype issues when switching models - # if device == "cpu": # Ensure torch default dtype matches if needed - # torch.set_default_dtype(torch.float32) - # _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call - calc = get_fairchem_model(selected_model, model_path, device, selected_task_type) - elif model_type == "ORB": - # st.write("Setting up ORB calculator...") - # orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype) - orbff = model_path(device=device, precision=selected_default_dtype) - calc = ORBCalculator(orbff, device=device) - elif model_type == "MatterSim": - # st.write("Setting up MatterSim calculator...") - # NOTE: Running mattersim on windows requires changing source code file - # https://github.com/microsoft/mattersim/issues/112 - # mattersim/datasets/utils/convertor.py: 117 - # to pbc_ = np.array(structure.pbc, dtype=np.int64) - calc = MatterSimCalculator(load_path=model_path, device=device) - elif model_type == "SEVEN_NET": - # st.write("Setting up SEVENNET calculator...") - if model_path=='7net-mf-ompa': - calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device) - else: - calc = SevenNetCalculator(model=model_path, device=device) - - - calc_atoms.calc = calc - - if task == "Energy Calculation": - energy = calc_atoms.get_potential_energy() - results["Energy"] = f"{energy:.6f} eV" - - elif task == "Energy + Forces Calculation": - energy = calc_atoms.get_potential_energy() - forces = calc_atoms.get_forces() - max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0 - results["Energy"] = f"{energy:.6f} eV" - results["Maximum Force"] = f"{max_force:.6f} eV/Å" - - elif task == "Atomization/Cohesive Energy": - st.write("Calculating system energy...") - E_system = calc_atoms.get_potential_energy() - num_atoms = len(calc_atoms) - - if num_atoms == 0: - st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.") - results["Error"] = "System has no atoms." - else: - atomic_numbers = calc_atoms.get_atomic_numbers() - E_isolated_atoms_total = 0.0 - calculation_possible = True - - if model_type == "FairChem": - st.write("Fetching FairChem reference energies for isolated atoms...") - ref_key_suffix = "_elem_refs" - chosen_ref_list_name = None - if "UMA Small" in selected_model: - if selected_task_type: - chosen_ref_list_name = selected_task_type + ref_key_suffix - elif "ESEN" in selected_model: - chosen_ref_list_name = "omol" + ref_key_suffix - - if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES: - ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name] - missing_Z_refs = [] - for Z_val in atomic_numbers: - if Z_val > 0 and Z_val < len(ref_energies): - E_isolated_atoms_total += ref_energies[Z_val] - else: - if Z_val not in missing_Z_refs: missing_Z_refs.append(Z_val) - if missing_Z_refs: - st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} " - f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). " - "These atoms are treated as having 0 reference energy.") - else: - st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' " - f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.") - results["Error"] = "Missing FairChem reference energies." - calculation_possible = False - - else:# == "MACE": - st.write("Calculating isolated atom energies with MACE...") - unique_atomic_numbers = sorted(list(set(atomic_numbers))) - atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers} - - progress_text = "Calculating isolated atom energies: 0% complete" - mace_progress_bar = st.progress(0, text=progress_text) - - for i, Z_unique in enumerate(unique_atomic_numbers): - isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False) - if not hasattr(isolated_atom, 'info'): isolated_atom.info = {} - isolated_atom.info["charge"] = 0 - isolated_atom.info["spin"] = 0 - isolated_atom.calc = calc # Use the same MACE calculator - - E_isolated_atom_type = isolated_atom.get_potential_energy() - E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique] - - progress_val = (i + 1) / len(unique_atomic_numbers) - mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete") - mace_progress_bar.empty() - - if calculation_possible: - is_periodic = any(calc_atoms.pbc) - if is_periodic: - cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms - results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom" - else: - atomization_E = E_isolated_atoms_total - E_system - results["Atomization Energy"] = f"{atomization_E:.6f} eV" - - results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV" - results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV" - - elif "Geometry Optimization" in task: # Handles both Geometry and Cell+Geometry Opt - is_periodic = any(calc_atoms.pbc) - opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms - # Create temporary trajectory file - traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name - if optimizer_type == "BFGS": - opt = BFGS(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "BFGSLineSearch": - opt = BFGSLineSearch(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "LBFGS": - opt = LBFGS(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "LBFGSLineSearch": - opt = LBFGSLineSearch(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "FIRE": - opt = FIRE(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "GPMin": - opt = GPMin(opt_atoms_obj, trajectory=traj_filename) - - elif optimizer_type == "MDMin": - opt = MDMin(opt_atoms_obj, trajectory=traj_filename) - - # opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly - opt.attach(lambda: streamlit_log(opt), interval=1) - - st.write(f"Running {task.lower()}...") - is_converged = opt.run(fmax=fmax, steps=max_steps) - - energy = calc_atoms.get_potential_energy() - forces = calc_atoms.get_forces() - max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0 - - results["Final Energy"] = f"{energy:.6f} eV" - results["Final Maximum Force"] = f"{max_force:.6f} eV/Å" - results["Steps Taken"] = opt.get_number_of_steps() - results["Converged"] = "Yes" if is_converged else "No" - if task == "Cell + Geometry Optimization": - results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist() - - st.success("Calculation completed successfully!") - st.markdown("### Results") - for key, value in results.items(): - st.write(f"**{key}:** {value}") - - if "Optimization" in task and "Final Energy" in results: # Check if opt was successful - st.markdown("### Optimized Structure") - - opt_view = get_structure_viz2(calc_atoms, style=viz_style, show_unit_cell=True, width=400, height=400) - st.components.v1.html(opt_view._make_html(), width=400, height=400) - - with tempfile.NamedTemporaryFile(delete=False, suffix=".xyz", mode="w+") as tmp_file_opt: - if is_periodic: - write(tmp_file_opt.name, calc_atoms, format="extxyz") - else: - write(tmp_file_opt.name, calc_atoms, format="xyz") - tmp_filepath_opt = tmp_file_opt.name - - with open(tmp_filepath_opt, 'r') as file_opt: - xyz_content_opt = file_opt.read() - @st.fragment - def show_optimized_structure_download_button(): - # st.button("Release the balloons", help="Fragment rerun") - # st.balloons() - st.download_button( - label="Download Optimized Structure (XYZ)", - data=xyz_content_opt, - file_name="optimized_structure.xyz", - mime="chemical/x-xyz" - ) - show_optimized_structure_download_button() - # --- Energy vs. Optimization Cycles Plot --- - @st.fragment - def show_energy_plot(traj_filename): - from ase.io import read - import pandas as pd - import plotly.express as px - import os - - if os.path.exists(traj_filename): - try: - trajectory = read(traj_filename, index=":") - - # Extract energy and step number - energies = [atoms.get_potential_energy() for atoms in trajectory] - steps = list(range(len(energies))) - - # Create a DataFrame for Plotly - data = { - "Optimization Cycle": steps, - "Energy (eV)": energies - } - df = pd.DataFrame(data) - - st.markdown("### Energy Profile During Optimization") - - # Create the Plotly figure - fig = px.line( - df, - x="Optimization Cycle", - y="Energy (eV)", - markers=True, # Show points for each step - title="Energy Convergence vs. Optimization Cycle", - ) - - # Enhance aesthetics - fig.update_layout( - xaxis_title="Optimization Cycle", - yaxis_title="Energy (eV)", - hovermode="x unified", - template="plotly_white", # Clean, professional look - font=dict(size=12), - title_x=0.5, # Center the title - ) - - # Highlight the converged energy (optional: useful if the plot is zoomed out) - fig.add_hline( - y=energies[-1], - line_dash="dot", - line_color="red", - annotation_text=f"Final Energy: {energies[-1]:.4f} eV", - annotation_position="bottom right" - ) - - # Render the plot in Streamlit - st.plotly_chart(fig, use_container_width=True) - - except Exception as e: - st.error(f"Error generating energy plot: {e}") - else: - st.warning("Cannot generate energy plot: Trajectory file not found.") - - show_energy_plot(traj_filename) - # --- End of Energy Plot Code --- - os.unlink(tmp_filepath_opt) - - @st.fragment - def show_trajectory_and_controls(): - from ase.io import read - import py3Dmol - - if "traj_frames" not in st.session_state: - if os.path.exists(traj_filename): - try: - trajectory = read(traj_filename, index=":") - st.session_state.traj_frames = trajectory - st.session_state.traj_index = 0 - except Exception as e: - st.error(f"Error reading trajectory: {e}") - return - # finally: - # os.unlink(traj_filename) - else: - st.warning("Trajectory file not found.") - return - - trajectory = st.session_state.traj_frames - index = st.session_state.traj_index - - st.markdown("### Optimization Trajectory") - st.write(f"Captured {len(trajectory)} optimization steps") - - # Navigation Buttons - col1, col2, col3, col4 = st.columns(4) - with col1: - if st.button("⏮ First"): - st.session_state.traj_index = 0 - with col2: - if st.button("◀ Previous") and index > 0: - st.session_state.traj_index -= 1 - with col3: - if st.button("Next ▶") and index < len(trajectory) - 1: - st.session_state.traj_index += 1 - with col4: - if st.button("Last ⏭"): - st.session_state.traj_index = len(trajectory) - 1 - - # Show current frame - current_atoms = trajectory[st.session_state.traj_index] - st.write(f"Frame {st.session_state.traj_index + 1}/{len(trajectory)}") - - def atoms_to_xyz_string(atoms, step_idx=None): - xyz_str = f"{len(atoms)}\n" - if step_idx is not None: - xyz_str += f"Step {step_idx}, Energy = {atoms.get_potential_energy():.6f} eV\n" - else: - xyz_str += f"Energy = {atoms.get_potential_energy():.6f} eV\n" - for atom in atoms: - xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n" - return xyz_str - - traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400) - st.components.v1.html(traj_view._make_html(), width=400, height=400) - - # Download button for entire trajectory - trajectory_xyz = "" - for i, atoms in enumerate(trajectory): - trajectory_xyz += atoms_to_xyz_string(atoms, i) - st.download_button( - label="Download Optimization Trajectory (XYZ)", - data=trajectory_xyz, - file_name="optimization_trajectory.xyz", - mime="chemical/x-xyz" - ) - - show_trajectory_and_controls() - elif task == "Global Optimization": - st.info(f"Starting Global Optimization using {global_method}...") - - # Create temporary trajectory file to store the "hopping" steps - traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name - - # Container for live updates - log_container = st.empty() - global_min_energy = 0 - - def global_log(opt_instance): - """Helper to log global optimization steps.""" - global global_min_energy - current_e = opt_instance.atoms.get_potential_energy() - # For BasinHopping, nsteps is available. For others, we might need a counter. - step = getattr(opt_instance, 'nsteps', 'N/A') - log_container.write(f"Global Step: {step} | Energy: {current_e:.6f} eV") - if current_e < global_min_energy: - global_min_energy = current_e - - if global_method == "Basin Hopping": - # Basin Hopping requires Temperature in eV (kB * T) - kT = temperature_K * kB - - # Create the wrapper for the hack needed to enforce the optimization to stop when it reaches a certain number of steps - class LimitedLBFGS(LBFGS): - def run(self, fmax=0.05, steps=None): - # 'steps' here overrides whatever BasinHopping tries to do. - # Set your desired max local steps (e.g., 200) - return super().run(fmax=fmax, steps=100) - # Initialize Basin Hopping with the trajectory file - bh = BasinHopping(calc_atoms, - temperature=kT, - dr=dr_amp, - optimizer=LimitedLBFGS, - fmax=fmax_local, - trajectory=traj_filename) # Log steps to file automatically - - # Attach the live logger - bh.attach(lambda: global_log(bh), interval=1) - - # Run the optimization - bh.run(global_steps) - - results["Global Minimum Energy"] = f"{global_min_energy:.6f} eV" - results["Steps Taken"] = global_steps - results["Converged"] = "N/A (Global Search)" - - elif global_method == "Minima Hopping": - # Minima Hopping manages its own internal optimizers and doesn't accept a 'trajectory' - # file argument in the same way BasinHopping does in __init__. - opt = MinimaHopping(calc_atoms, - T0=temperature_K, - fmax=fmax_local, - optimizer=LBFGS) - - # We run it. Live logging is harder here without subclassing, - # so we rely on the final output for the trajectory. - opt(totalsteps=global_steps) - - results["Current Energy"] = f"{calc_atoms.get_potential_energy():.6f} eV" - - # Post-processing: MinimaHopping stores visited minima in an internal list usually. - # We explicitly write the found minima to the trajectory file so the visualizer below works. - # Note: opt.minima is a list of Atoms objects found during the hop. - if hasattr(opt, 'minima'): - from ase.io import write - write(traj_filename, opt.minima) - else: - # Fallback if specific version doesn't store list, just save final - write(traj_filename, calc_atoms) - - st.success("Global Optimization Complete!") - - st.markdown("### Results") - for key, value in results.items(): - st.write(f"**{key}:** {value}") - - # --- Visualization and Downloading (Fragmented) --- - - # 1. Clean up the temp file path for reading - # We define the visualizer function using @st.fragment to prevent full re-runs - - @st.fragment - def show_global_trajectory_and_dl(): - from ase.io import read - import py3Dmol - - # Helper to convert atoms list to XYZ string for the single download file - def atoms_list_to_xyz_string(atoms_list): - xyz_str = "" - for i, atoms in enumerate(atoms_list): - xyz_str += f"{len(atoms)}\n" - xyz_str += f"Step {i}, Energy = {atoms.get_potential_energy():.6f} eV\n" - for atom in atoms: - xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n" - return xyz_str - - if "global_traj_frames" not in st.session_state: - if os.path.exists(traj_filename): - try: - # Read the trajectory we just created - trajectory = read(traj_filename, index=":") - st.session_state.global_traj_frames = trajectory - st.session_state.global_traj_index = 0 - except Exception as e: - st.error(f"Error reading trajectory: {e}") - return - else: - st.warning("Trajectory file not generated.") - return - - trajectory = st.session_state.global_traj_frames - - if not trajectory: - st.warning("No steps recorded in trajectory.") - return - - index = st.session_state.global_traj_index - - st.markdown("### Global Search Trajectory") - st.write(f"Captured {len(trajectory)} hopping steps (Local Minima)") - - # Navigation Controls - col1, col2, col3, col4 = st.columns(4) - with col1: - if st.button("⏮ First", key="g_first"): - st.session_state.global_traj_index = 0 - with col2: - if st.button("◀ Previous", key="g_prev") and index > 0: - st.session_state.global_traj_index -= 1 - with col3: - if st.button("Next ▶", key="g_next") and index < len(trajectory) - 1: - st.session_state.global_traj_index += 1 - with col4: - if st.button("Last ⏭", key="g_last"): - st.session_state.global_traj_index = len(trajectory) - 1 - - # Display Visualization - current_atoms = trajectory[st.session_state.global_traj_index] - st.write(f"Frame {st.session_state.global_traj_index + 1}/{len(trajectory)} | E = {current_atoms.get_potential_energy():.4f} eV") - - viz_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400) - st.components.v1.html(viz_view._make_html(), width=400, height=400) - - # Download Logic - full_xyz_content = atoms_list_to_xyz_string(trajectory) - - st.download_button( - label="Download Trajectory (XYZ)", - data=full_xyz_content, - file_name="global_optimization_path.xyz", - mime="chemical/x-xyz" - ) - - # Separate Download for just the Best Structure (Last frame usually in BH, or sorted) - # Often in BH, the last frame is the accepted state, but not necessarily the global min seen *ever*. - # But usually, we want the lowest energy one. - energies = [a.get_potential_energy() for a in trajectory] - best_idx = np.argmin(energies) - best_atoms = trajectory[best_idx] - - # Create XYZ for single best - with tempfile.NamedTemporaryFile(mode='w', suffix=".xyz", delete=False) as tmp_best: - write(tmp_best.name, best_atoms) - tmp_best_name = tmp_best.name - - with open(tmp_best_name, "r") as f: - st.download_button( - label=f"Download Best Structure (E={energies[best_idx]:.4f} eV)", - data=f.read(), - file_name="best_global_structure.xyz", - mime="chemical/x-xyz" - ) - os.unlink(tmp_best_name) - - # Call the fragment function - show_global_trajectory_and_dl() - - # Cleanup main trajectory file after loading it into session state if desired, - # though keeping it until session end is safer for re-reads. - # os.unlink(traj_filename) - elif task == "Vibrational Mode Analysis": - # Conversion factors - from ase.units import kB as kB_eVK, _Nav, J # ASE's constants - from scipy.constants import physical_constants - kB_JK = physical_constants["Boltzmann constant"][0] # J/K - is_periodic = any(calc_atoms.pbc) - st.write("Running vibrational mode analysis using finite differences...") - - natoms = len(calc_atoms) - is_linear = False # Set manually or auto-detect - nmodes_expected = 3 * natoms - (5 if is_linear else 6) - - # Create temporary directory to store .vib files - with tempfile.TemporaryDirectory() as tmpdir: - vib = Vibrations(calc_atoms, name=os.path.join(tmpdir, 'vib')) - - with st.spinner("Calculating vibrational modes... This may take a few minutes."): - vib.run() - freqs = vib.get_frequencies() - energies = vib.get_energies() - - print('\n\n\n\n\n\n\n\n') - # vib.get_hessian_2d() - # st.write(vib.summary()) - # print('\n') - # vib.tabulate() - - freqs_cm = freqs - freqs_eV = energies - - # Classify frequencies - mode_data = [] - for i, freq in enumerate(freqs_cm): - if freq < 0: - label = "Imaginary" - elif abs(freq) < 500: - label = "Low" - else: - label = "Physical" - mode_data.append({ - "Mode": i + 1, - "Frequency (cm⁻¹)": round(freq, 2), - "Type": label - }) - - df_modes = pd.DataFrame(mode_data) - - # Display summary and mode count - st.success("Vibrational analysis completed.") - st.write(f"Number of atoms: {natoms}") - st.write(f"Expected vibrational modes: {nmodes_expected}") - st.write(f"Found {len(freqs_cm)} modes (including translational/rotational modes).") - - # Show table of modes - st.write("### Vibrational Mode Summary") - st.dataframe(df_modes, use_container_width=True) - - # Store in results dictionary - results["Vibrational Modes"] = df_modes.to_dict(orient="records") - - # Histogram plot of vibrational frequencies - st.write("### Frequency Distribution Histogram") - fig, ax = plt.subplots() - ax.hist(freqs_cm, bins=30, color='skyblue', edgecolor='black') - ax.set_xlabel("Frequency (cm⁻¹)") - ax.set_ylabel("Number of Modes") - ax.set_title("Distribution of Vibrational Frequencies") - st.pyplot(fig) - - # CSV download - csv_buffer = io.StringIO() - df_modes.to_csv(csv_buffer, index=False) - st.download_button( - label="Download Vibrational Frequencies (CSV)", - data=csv_buffer.getvalue(), - file_name="vibrational_modes.csv", - mime="text/csv" - ) - # -------- Thermodynamic Analysis for Molecules -------- - if not is_periodic: - - # Filter physical frequencies > 1 cm⁻¹ (to avoid numerical issues) - physical_freqs_eV = np.array([f for f in freqs_eV if f > 1e-5]) - - # Zero-point vibrational energy (ZPE) - ZPE = 0.5 * np.sum(physical_freqs_eV) # in eV - - # Vibrational entropy (in eV/K) - vib_entropy = 0.0 - for f in physical_freqs_eV: - x = f / (kB_eVK * T) - vib_entropy += (x / (np.exp(x) - 1) - np.log(1 - np.exp(-x))) - - S_vib_eVK = kB_eVK * vib_entropy # eV/K - S_vib_JmolK = S_vib_eVK * J * _Nav # J/mol·K - - results["ZPE (eV)"] = ZPE.real - results["Vibrational Entropy (eV/K)"] = S_vib_eVK - results["Vibrational Entropy (J/mol·K)"] = S_vib_JmolK - - st.write(f"**Zero-point vibrational energy (ZPE)**: {ZPE.real:.6f} eV") - st.write(f"**Vibrational entropy**: {S_vib_eVK:.6f} eV/K") - - else: - st.info("Thermodynamic properties like ZPE and entropy are currently only meaningful for isolated molecules (non-periodic systems).") - elif task == "Phonons": - from ase.phonons import Phonons - - st.write("### Phonon Band Structure and Density of States") - is_periodic = any(calc_atoms.pbc) - - if not is_periodic: - st.error("Phonon calculations require a periodic structure. Please use a periodic system.") - else: - with tempfile.TemporaryDirectory() as tmpdir: - st.info("Running phonon calculation using finite displacements...") - sc = (7, 7, 7) - - # Create phonon object - ph = Phonons(calc_atoms, calc_atoms.calc, supercell=sc, delta=0.001, name=os.path.join(tmpdir, 'phonon')) - - with st.spinner("Displacing atoms and computing forces..."): - ph.run() - - # Build dynamical matrix - ph.read(acoustic=True) - ph.clean() - - # Band path and DOS - # path = calc_atoms.cell.bandpath('GXULGK', npoints=100) - path = calc_atoms.cell.bandpath('GXKGL', npoints=100) - # path = calc_atoms.cell.bandpath(eps=0.00001) - bs = ph.get_band_structure(path) - dos = ph.get_dos(kpts=(20, 20, 20)).sample_grid(npts=100, width=1e-3) - - # Plotting - fig = plt.figure(figsize=(7, 4)) - ax = fig.add_axes([0.12, 0.07, 0.67, 0.85]) - - emax = 0.075 - bs.plot(ax=ax, emin=0.0, emax=emax) - - dosax = fig.add_axes([0.8, 0.07, 0.17, 0.85]) - dosax.fill_between( - dos.get_weights(), - dos.get_energies(), - y2=0, - color='grey', - edgecolor='k', - lw=1, - ) - dosax.set_ylim(0, emax) - dosax.set_yticks([]) - dosax.set_xticks([]) - dosax.set_xlabel('DOS', fontsize=14) - - st.pyplot(fig) - st.success("Phonon band structure and DOS successfully plotted.") - except Exception as e: - st.error(f"🔴 Calculation error: {str(e)}") - st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).") - import traceback - st.error(f"Traceback: {traceback.format_exc()}") - -else: - st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.") - -st.markdown("---") -with st.expander('ℹ️ About This App & Foundational MLIPs'): - st.write(""" - **Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).** - This application allows you to perform atomistic simulations using pre-trained foundational MLIPs - from the MACE, MatterSim (Microsoft), SevenNet, Orb (Orbital Materials) and FairChem (Meta AI) developers and researchers. - - **Features:** - - Upload/Paste structure files (XYZ, CIF, POSCAR, etc.), import from Materials Project/PubChem or use built-in examples. - - Select from various MACE, ORB, SevenNet, MatterSim and FairChem models. - - Calculate energies, forces, cohesive/atomization energy, vibrational modes and perform geometry/cell optimizations. - - Visualize atomic structures in 3D and download results, optimized structures and optimization trajectories. - - **Quick Start:** - 1. **Input**: Choose an input method in the sidebar (e.g., "Select Example"). - 2. **Model**: Pick a model type (MACE/FairChem/MatterSim/ORB/SevenNet) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials). - For models trained on OMOL25 dataset (whenever the model name contains `omol`) then the user also needs to provide a charge and spin multiplicity (`2S+1`) value. By default the charge is set to zero and spin multiplicity to 1 (S=0). - 3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization"). - 4. **Run**: Click "Run Calculation" and view the results. - - **Atomization/Cohesive Energy Notes:** - - **Atomization Energy** ($E_{\\text{atomization}} = \sum E_{\\text{isolated atoms}} - E_{\\text{molecule}}$) is typically for non-periodic systems (molecules). - - **Cohesive Energy** ($E_{\\text{cohesive}} = (\sum E_{\\text{isolated atoms}} - E_{\\text{bulk system}}) / N_{\\text{atoms}}$) is for periodic systems. - - For **MACE models**, isolated atom energies are computed on-the-fly. - - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references. - """) - - -with st.expander('🔧 Tech Stack & System Information'): - import platform - import psutil - - st.markdown("### System Information") - - col1, col2 = st.columns(2) - - with col1: - st.write("**Operating System:**") - st.write(f"- OS: {platform.system()} {platform.release()}") - st.write(f"- Version: {platform.version()}") - st.write(f"- Architecture: {platform.machine()}") - st.write(f"- Processor: {platform.processor()}") - - st.write("\n**Python Environment:**") - st.write(f"- Python Version: {platform.python_version()}") - st.write(f"- Python Implementation: {platform.python_implementation()}") - - with col2: - st.write("**Hardware Resources:**") - st.write(f"- CPU Cores: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count(logical=True)} logical") - st.write(f"- CPU Usage: {psutil.cpu_percent(interval=1)}%") - - memory = psutil.virtual_memory() - st.write(f"- Total RAM: {memory.total / (1024**3):.2f} GB") - st.write(f"- Available RAM: {memory.available / (1024**3):.2f} GB") - st.write(f"- RAM Usage: {memory.percent}%") - - disk = psutil.disk_usage('/') - st.write(f"- Total Disk Space: {disk.total / (1024**3):.2f} GB") - st.write(f"- Free Disk Space: {disk.free / (1024**3):.2f} GB") - st.write(f"- Disk Usage: {disk.percent}%") - - st.markdown("### Package Versions") - - packages_to_check = [ - 'streamlit', 'torch', 'numpy', 'ase', 'py3Dmol', - 'mace-torch', 'fairchem-core', 'orb-models', 'sevenn', - 'pandas', 'matplotlib', 'scipy', 'yaml', 'huggingface-hub' - ] - - if mattersim_available: - packages_to_check.append('mattersim') - - package_versions = {} - for package in packages_to_check: - try: - version = pkg_resources.get_distribution(package).version - package_versions[package] = version - except pkg_resources.DistributionNotFound: - package_versions[package] = "Not installed" - - # Display in two columns - col1, col2 = st.columns(2) - items = list(package_versions.items()) - mid_point = len(items) // 2 - - with col1: - for package, version in items[:mid_point]: - st.write(f"**{package}:** {version}") - - with col2: - for package, version in items[mid_point:]: - st.write(f"**{package}:** {version}") - - # PyTorch specific information - st.markdown("### PyTorch Configuration") - st.write(f"**PyTorch Version:** {torch.__version__}") - st.write(f"**CUDA Available:** {torch.cuda.is_available()}") - if torch.cuda.is_available(): - st.write(f"**CUDA Version:** {torch.version.cuda}") - st.write(f"**cuDNN Version:** {torch.backends.cudnn.version()}") - st.write(f"**Number of GPUs:** {torch.cuda.device_count()}") - for i in range(torch.cuda.device_count()): - st.write(f"**GPU {i}:** {torch.cuda.get_device_name(i)}") - else: - st.write("Running on CPU only") - -st.markdown("---") -st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️") -st.markdown("Developed by [Dr. Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan Group](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)") \ No newline at end of file