# Hugging Face Space — status: Running
| import streamlit as st | |
| import os | |
| import io | |
| import tempfile | |
| import torch | |
| # FOR CPU only mode | |
| # torch._dynamo.config.suppress_errors = True | |
| # Or disable compilation entirely | |
| # torch.backends.cudnn.enabled = False | |
| import numpy as np | |
| from ase import Atoms | |
| from ase.io import read, write | |
| from ase.optimize import BFGS, LBFGS, FIRE | |
| from ase.constraints import FixAtoms | |
| from ase.filters import FrechetCellFilter | |
| from ase.visualize import view | |
| import py3Dmol | |
| from mace.calculators import mace_mp | |
| from fairchem.core import pretrained_mlip, FAIRChemCalculator | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
| from sevenn.calculator import SevenNetCalculator | |
| import pandas as pd | |
| import yaml # Added for FairChem reference energies | |
| import subprocess | |
| import sys | |
| import pkg_resources | |
| from ase.vibrations import Vibrations | |
| import matplotlib.pyplot as plt | |
# MatterSim is optional. Import it defensively so the app still starts, with
# MatterSim simply left out of the model-type picker, when the package is
# missing or fails to import. Previously the flag was hard-coded to True and
# the import was unconditional, so a missing package crashed the whole app.
# Earlier revisions also pip-installed mattersim and pinned numpy<2.0.0 at
# runtime; such pins belong in the environment/requirements, not in app code.
try:
    from mattersim.forcefield import MatterSimCalculator
    mattersim_available = True
except ImportError as e:
    print(f"Failed to import MatterSimCalculator: {e}")
    mattersim_available = False
| from huggingface_hub import login | |
# Log in to the Hugging Face Hub when a token is supplied, so that models that
# require authentication can be downloaded. The token is read from the
# HF_TOKEN environment variable (e.g. configured as a Space secret). Without
# it the app still runs, but some models might not be accessible.
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:  # best effort: a failed login must not stop the app
    print(f"hf login error: {e}")
# NOTE(review): Streamlit reads its file-watcher setting at server start-up
# (config option server.fileWatcherType); setting this variable from inside
# the script may have no effect -- confirm.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
# Per-element reference energies used with FairChem models, as a YAML document
# with one list per task head (oc20, odac, omat, omol, omc). Every list has
# 100 entries. Presumably index = atomic number (index 0 unused) and values are
# in eV -- TODO confirm against the FairChem release these were copied from.
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
def _parse_element_ref_energies(yaml_text):
    """Parse the embedded reference-energy YAML; return {} if it is malformed.

    Runs at import time, before any Streamlit page output, so a parse failure
    is reported with print() rather than st.error().
    """
    try:
        return yaml.safe_load(yaml_text)
    except yaml.YAMLError as err:
        print(f"Error parsing YAML reference energies: {err}")
        return {}


ELEMENT_REF_ENERGIES = _parse_element_ref_energies(ELEMENT_REF_ENERGIES_YAML)

# Deployment detection and per-model atom caps (see check_atom_limit).
is_streamlit_cloud = os.getenv('STREAMLIT_RUNTIME_ENV') == 'cloud'
MAX_ATOMS_CLOUD = 500      # default cap for every model
MAX_ATOMS_CLOUD_UMA = 500  # cap for UMA / "ESEN MD" models
# Page configuration -- kept ahead of every other st.* call in the app.
st.set_page_config(
    layout="wide",
    page_icon="🧪",
    page_title="MLIP Playground - Run, Test and Benchmark MLIPs",
)

# Header: app name, one-line pitch, and a short usage hint.
# NOTE(review): the model tables defined further down list 31 models (29 when
# MatterSim is unavailable); the "22" in the pitch may be out of date.
_APP_PITCH = (
    "#### Run, test and compare 22 state-of-the-art universal machine learning "
    "interatomic potentials (MLIPs) for atomistic simulations of molecules and materials"
)
_APP_USAGE = (
    "Upload molecular structure files or select from predefined examples, then compute "
    "energies and forces using foundation models such as those from MACE or FairChem (Meta)."
)
st.markdown('## MLIP Playground', unsafe_allow_html=True)
st.write(_APP_PITCH)
st.markdown(_APP_USAGE, unsafe_allow_html=True)
# Folder holding the bundled example structures; created if it is missing.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Sidebar display name -> file name inside SAMPLE_DIR (order = menu order).
SAMPLE_STRUCTURES = {
    # Molecules
    "Water": "H2O.xyz",
    "Methane": "CH4.xyz",
    "Benzene": "C6H6.xyz",
    "Ethane": "C2H6.xyz",
    "Caffeine": "caffeine.xyz",
    "Ibuprofen": "ibuprofen.xyz",
    # Periodic systems
    "Silicon": "Si.cif",
    "hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
}
def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
                       show_path=True, path_color='red', path_radius=0.02):
    """
    Build a py3Dmol view that overlays every frame of an optimization trajectory.

    Each frame is added as its own XYZ model, so all frames are drawn at the
    same time. Optionally, segments joining each atom's positions in
    consecutive frames are drawn, and the final frame's unit cell is outlined.

    Args:
        trajectory: sequence of ASE ``Atoms`` objects, one per optimization step.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); any other
            value falls back to thin sticks.
        show_unit_cell: outline the last frame's cell when it is periodic.
        width, height: viewer size in pixels.
        show_path: draw per-atom path segments between consecutive frames.
        path_color: colour of the path segments.
        path_radius: radius of the path segments.

    Returns:
        The configured ``py3Dmol.view``, or None for an empty trajectory.
    """
    if not trajectory:
        return None

    def _point(vec):
        return {'x': vec[0], 'y': vec[1], 'z': vec[2]}

    viewer = py3Dmol.view(width=width, height=height)

    # One XYZ model per frame.
    for frame_no, frame in enumerate(trajectory):
        rows = [str(len(frame)), f"Frame {frame_no}"]
        rows += [
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
            for atom in frame
        ]
        viewer.addModel("\n".join(rows) + "\n", "xyz")

    style_specs = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    # Per-atom displacement segments, atom by atom, step by step.
    n_frames = len(trajectory)
    if show_path and n_frames > 1:
        for atom_idx in range(len(trajectory[0])):
            for step in range(1, n_frames):
                before = trajectory[step - 1][atom_idx].position
                after = trajectory[step][atom_idx].position
                viewer.addCylinder({
                    'start': _point(before),
                    'end': _point(after),
                    'radius': path_radius,
                    'color': path_color,
                    'alpha': 0.5,
                })

    # Unit cell of the final frame as 12 black edges.
    final = trajectory[-1]
    if show_unit_cell and final.pbc.any():
        cell = final.get_cell()
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            o = np.array([0.0, 0.0, 0.0])
            edges = (
                (o, a), (o, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (o, c), (a, a + c), (b, b + c),
                (a + b, a + b + c),
            )
            for p, q in edges:
                viewer.addCylinder({
                    'start': _point(p),
                    'end': _point(q),
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7,
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400):
    """
    Create an animated py3Dmol view that plays an optimization trajectory.

    All frames are packed into one multi-frame XYZ string and loaded with
    ``addModelsAsFrames``, so that ``animate`` steps through them one at a
    time. Adding each frame as a separate model with ``addModel`` makes 3Dmol
    render all of them simultaneously. ``animate`` only steps through frames
    *within* a model, so the "animation" showed every step overlaid.

    Args:
        trajectory: sequence of ASE ``Atoms`` objects, one per optimization step.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); any other
            value falls back to thin sticks.
        show_unit_cell: outline the last frame's cell when it is periodic.
        width, height: viewer size in pixels.

    Returns:
        The configured ``py3Dmol.view``, or None for an empty trajectory.
    """
    if not trajectory:
        return None

    viewer = py3Dmol.view(width=width, height=height)

    # Consecutive XYZ blocks in one string become successive frames of one model.
    chunks = []
    for frame_idx, atoms_obj in enumerate(trajectory):
        chunks.append(f"{len(atoms_obj)}\nFrame {frame_idx}\n")
        chunks.extend(
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
            for atom in atoms_obj
        )
    viewer.addModelsAsFrames("".join(chunks), "xyz")

    style_key = style.lower()
    if style_key == 'ball-stick':
        viewer.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
    elif style_key == 'stick':
        viewer.setStyle({'stick': {}})
    elif style_key == 'ball':
        viewer.setStyle({'sphere': {'scale': 0.4}})
    else:
        viewer.setStyle({'stick': {'radius': 0.15}})

    # The unit cell is a static shape taken from the last frame.
    if show_unit_cell and trajectory[-1].pbc.any():
        cell = trajectory[-1].get_cell()
        origin = np.array([0.0, 0.0, 0.0])
        if cell is not None and cell.any():
            edges = [
                (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
                (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
                (cell[0] + cell[1], cell[0] + cell[1] + cell[2]),
                (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
                (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1]),
            ]
            for start, end in edges:
                viewer.addCylinder({
                    'start': {'x': start[0], 'y': start[1], 'z': start[2]},
                    'end': {'x': end[0], 'y': end[1], 'z': end[2]},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7,
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    # Loop forever ('reps': 0), 500 ms per frame.
    viewer.animate({'loop': 'forward', 'reps': 0, 'interval': 500})
    return viewer
# Streamlit front-end for browsing an optimization trajectory.
def display_optimization_trajectory(trajectory, viz_style='ball-stick'):
    """
    Display an optimization trajectory in Streamlit with a mode selector.

    Modes:
        "Animation"         -- looping animation of all frames.
        "Static with paths" -- all frames overlaid, with optional per-atom paths.
        "Step-by-step"      -- a single frame chosen with a slider.

    Args:
        trajectory: sequence of ASE ``Atoms`` objects, one per step.
        viz_style: style name forwarded to the viewer builders.
    """
    if not trajectory:
        st.error("No trajectory data available")
        return

    n_frames = len(trajectory)
    st.subheader(f"Optimization Trajectory ({n_frames} steps)")

    col1, col2 = st.columns(2)
    with col1:
        viz_mode = st.selectbox(
            "Visualization Mode",
            ["Animation", "Static with paths", "Step-by-step"],
            key="viz_mode",
        )

    frame_idx = 0
    with col2:
        if viz_mode == "Static with paths":
            show_paths = st.checkbox("Show trajectory paths", value=True)
            path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0)
        elif viz_mode == "Step-by-step" and n_frames > 1:
            # Skip the slider for a single-frame trajectory: min and max would
            # both be 0, which Streamlit's slider rejects.
            frame_idx = st.slider("Frame", 0, n_frames - 1, 0, key="frame_slider")

    if viz_mode == "Static with paths":
        opt_view = get_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400,
            show_path=show_paths,
            path_color=path_color,
        )
    elif viz_mode == "Animation":
        opt_view = get_animated_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400,
        )
    else:  # "Step-by-step"
        opt_view = get_structure_viz2(
            trajectory[frame_idx],
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400,
        )

    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    if viz_mode == "Step-by-step":
        st.write(f"Step {frame_idx + 1} of {n_frames}")
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """
    Build a py3Dmol view of a single structure.

    Args:
        atoms_obj: ASE ``Atoms`` to draw; its positions are written as XYZ.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); any other
            value falls back to thin sticks.
        show_unit_cell: outline the cell as black cylinders when the structure
            is periodic along at least one axis and its cell is non-zero.
        width, height: viewer size in pixels.

    Returns:
        The configured ``py3Dmol.view``.
    """
    atom_rows = [
        f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
        for atom in atoms_obj
    ]
    xyz_text = "\n".join([str(len(atoms_obj)), "Structure", *atom_rows]) + "\n"

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(xyz_text, "xyz")

    style_specs = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        # Skip a missing or all-zero cell (nothing meaningful to outline).
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            o = np.array([0.0, 0.0, 0.0])
            edges = (
                (o, a), (o, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (o, c), (a, a + c), (b, b + c),
                (a + b, a + b + c),
            )
            for p, q in edges:
                viewer.addCylinder({
                    'start': {'x': p[0], 'y': p[1], 'z': p[2]},
                    'end': {'x': q[0], 'y': q[1], 'z': q[2]},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7,
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
# Rows logged by streamlit_log during an optimization, and the placeholder it
# redraws as a table after every step.
opt_log = []
table_placeholder = st.empty()


def streamlit_log(opt):
    """
    Optimizer observer: record the current step and refresh the log table.

    Appends ``Step`` (``opt.nsteps``), the potential energy and the maximum
    per-atom force norm of ``opt.atoms`` to ``opt_log``, both rounded to 6
    decimals. It then redraws ``table_placeholder`` with the full log.
    Any exception is shown via ``st.warning`` instead of being raised.
    """
    try:
        energy = opt.atoms.get_potential_energy()
        forces = opt.atoms.get_forces()
        if forces.shape[0] > 0:
            fmax_step = np.linalg.norm(forces, axis=1).max()
        else:
            fmax_step = 0.0
        row = {
            "Step": opt.nsteps,
            "Energy (eV)": round(energy, 6),
            "Fmax (eV/Å)": round(fmax_step, 6),
        }
        opt_log.append(row)
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as e:
        st.warning(f"Error in optimization logger: {e}")
def check_atom_limit(atoms_obj, selected_model):
    """
    Check whether ``atoms_obj`` is within the atom cap for ``selected_model``.

    Model names containing 'UMA' or 'ESEN MD' use MAX_ATOMS_CLOUD_UMA; every
    other model uses MAX_ATOMS_CLOUD. A missing structure (None) passes.

    Returns:
        True if allowed; otherwise shows an ``st.error`` and returns False.
    """
    # NOTE(review): the cap is applied regardless of is_streamlit_cloud, even
    # though the message refers to Streamlit Cloud -- confirm this is intended.
    if atoms_obj is None:
        return True
    uses_uma_cap = 'UMA' in selected_model or 'ESEN MD' in selected_model
    limit = MAX_ATOMS_CLOUD_UMA if uses_uma_cap else MAX_ATOMS_CLOUD
    num_atoms = len(atoms_obj)
    if num_atoms <= limit:
        return True
    st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# Model registries: sidebar display name -> checkpoint URL / identifier.

# MACE checkpoints are loaded directly from their GitHub release URLs.
MACE_MODELS = {
    "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model",
    "MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
}

# FairChem names understood by pretrained_mlip.get_predict_unit.
FAIRCHEM_MODELS = {
    "UMA Small": "uma-s-1",
    "ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
    "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
    "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol",
}

# ORB v3 variants: training set (OMAT/MPA) x conservative/direct x neighbour cap.
ORB_MODELS = {
    "V3 OMAT Conservative (inf)": "orb-v3-conservative-inf-omat",
    "V3 OMAT Conservative (20)": "orb-v3-conservative-20-omat",
    "V3 OMAT Direct (inf)": "orb-v3-direct-inf-omat",
    "V3 OMAT Direct (20)": "orb-v3-direct-20-omat",
    "V3 MPA Conservative (inf)": "orb-v3-conservative-inf-mpa",
    "V3 MPA Conservative (20)": "orb-v3-conservative-20-mpa",
    "V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
    "V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
}

# MatterSim checkpoint file names.
MATTERSIM_MODELS = {
    "V1 SMALL": "MatterSim-v1.0.0-1M.pth",
    "V1 LARGE": "MatterSim-v1.0.0-5M.pth",
}

# SevenNet models are addressed by the same name they are displayed under.
SEVEN_NET_MODELS = {name: name for name in ("7net-0", "7net-l3i5", "7net-omat", "7net-mf-ompa")}
def get_mace_model(model_path, device, selected_default_dtype):
    """Build a MACE calculator via ``mace_mp``.

    Args:
        model_path: checkpoint passed as ``model`` (e.g. a URL from MACE_MODELS).
        device: torch device string, "cpu" or "cuda".
        selected_default_dtype: forwarded as ``default_dtype`` (e.g. "float64").
    """
    calculator = mace_mp(
        model=model_path,
        device=device,
        default_dtype=selected_default_dtype,
    )
    return calculator
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc):
    """Build a FAIRChemCalculator for a FairChem checkpoint.

    Args:
        selected_model_name: sidebar display name (a FAIRCHEM_MODELS key).
        model_path_or_name: name passed to ``pretrained_mlip.get_predict_unit``.
        device: torch device string, "cpu" or "cuda".
        selected_task_type_fc: task head to use; honoured only for "UMA Small".
            Every other FairChem entry always uses the "omol" task.
    """
    predict_unit = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    task = selected_task_type_fc if selected_model_name == "UMA Small" else "omol"
    return FAIRChemCalculator(predict_unit, task_name=task)
def _read_structure_bytes(data, filename):
    """
    Parse structure-file bytes with ASE ``read`` via a throw-away file.

    ASE selects the reader largely from the file name/extension, so the bytes
    are written to ``filename`` inside a temporary directory. The directory is
    removed afterwards even when writing or parsing fails.

    Args:
        data: raw file contents.
        filename: bare file name whose extension/name guides format detection.

    Returns:
        The ASE ``Atoms`` returned by ``read``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, os.path.basename(filename) or "structure.xyz")
        with open(tmp_path, "wb") as handle:
            handle.write(data)
        return read(tmp_path)


st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
atoms = None

if input_method == "Upload File":
    # "extxyz" added: paste mode and the bundled examples already use it.
    uploaded_file = st.sidebar.file_uploader(
        "Upload structure file",
        type=["xyz", "extxyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"],
    )
    if uploaded_file:
        try:
            # Keep the original name so ASE sees the same extension/name hints.
            atoms = _read_structure_bytes(uploaded_file.getvalue(), uploaded_file.name)
            st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {str(e)}")
elif input_method == "Select Example":
    example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
    if example_name:
        file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
        try:
            atoms = read(file_path)
            st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading example: {str(e)}")
elif input_method == "Paste Content":
    file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
    content = st.sidebar.text_area("Paste file content here:", height=200)
    if content:
        suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
        suffix = suffix_map.get(file_format, ".xyz")
        try:
            atoms = _read_structure_bytes(content.encode(), f"pasted{suffix}")
            st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error parsing content: {str(e)}")
# Give every loaded structure "charge" and "spin" entries (default 0) while
# keeping any values the file already provided.
# NOTE(review): "spin" semantics are ambiguous -- the UMA sidebar below treats
# it as S when computing 2S+1, then stores the multiplicity back into it.
if atoms is not None:
    if not hasattr(atoms, 'info'):
        atoms.info = {}
    for info_key in ("charge", "spin"):
        atoms.info[info_key] = atoms.info.get(info_key, 0)
st.sidebar.markdown("## Model Selection")

# MatterSim is only offered when its optional import succeeded.
model_type_options = ["MACE", "FairChem", "ORB", "SEVEN_NET"]
if mattersim_available:
    model_type_options.append("MatterSim")
model_type = st.sidebar.radio("Select Model Type:", model_type_options)

selected_task_type = None  # only set for the multi-task FairChem UMA model

if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    # Precision is fixed; a float32/float64 selector used to be offered here.
    selected_default_dtype = 'float64'
elif model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if selected_model == "UMA Small":
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
        if selected_task_type == "omol" and atoms is not None:
            # st.number_input rejects a default that lies outside [min, max] or
            # whose numeric type differs from the bounds. Structure metadata may
            # hold floats or NumPy scalars, so coerce to int and clamp first.
            raw_charge = atoms.info.get("charge", 0)
            default_charge = int(raw_charge) if raw_charge is not None else 0
            default_charge = min(max(default_charge, -10), 10)
            charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=default_charge)

            raw_spin = atoms.info.get("spin", 0)  # interpreted as S here
            default_multiplicity = int(raw_spin * 2 + 1) if raw_spin is not None else 1
            default_multiplicity = min(max(default_multiplicity, 1), 11)
            spin_multiplicity = st.sidebar.number_input(
                "Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=default_multiplicity
            )

            atoms.info["charge"] = charge
            atoms.info["spin"] = spin_multiplicity  # FairChem expects multiplicity
elif model_type == "ORB":
    selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
    model_path = ORB_MODELS[selected_model]
    # Display names spell it "OMAT"; compare case-insensitively so the licence
    # notice is actually shown for the OMAT checkpoints.
    if "omat" in selected_model.lower():
        st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
    # float64 is intentionally not offered for ORB.
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest'])
elif model_type == "MatterSim":
    selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
    model_path = MATTERSIM_MODELS[selected_model]
elif model_type == "SEVEN_NET":
    selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
    if selected_model == '7net-mf-ompa':
        selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat24', 'mpa'])
    model_path = SEVEN_NET_MODELS[selected_model]
| if atoms is not None: | |
| if not check_atom_limit(atoms, selected_model): | |
| st.stop() # Stop execution if limit exceeded | |
# Default the radio to the GPU when one is present; requesting CUDA without a
# GPU silently falls back to the CPU.
cuda_available = torch.cuda.is_available()
device_choice = st.sidebar.radio(
    "Computation Device:",
    ["CPU", "CUDA (GPU)"],
    index=1 if cuda_available else 0,
)
device = "cuda" if (device_choice == "CUDA (GPU)" and cuda_available) else "cpu"
if device == "cpu":
    st.sidebar.info(
        "GPU is available but CPU was selected." if cuda_available else "No GPU detected. Using CPU."
    )
st.sidebar.markdown("## Task Selection")
task_options = [
    "Energy Calculation",
    "Energy + Forces Calculation",
    "Atomization/Cohesive Energy",
    "Geometry Optimization",
    "Cell + Geometry Optimization",
    "Vibrational Mode Analysis",
]
task = st.sidebar.selectbox("Select Calculation Task:", task_options)

# Both optimization tasks share the same convergence controls.
if "Optimization" in task:
    st.sidebar.markdown("### Optimization Parameters")
    max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
    fmax = st.sidebar.slider(
        "Convergence Threshold (eV/Å):",
        min_value=0.001,
        max_value=0.1,
        value=0.01,
        step=0.001,
        format="%.3f",
    )
    optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)  # LBFGS by default
if atoms is not None:
    col1, col2 = st.columns(2)
    # Left column: 3D view of the input plus a short structure summary.
    with col1:
        st.markdown('### Structure Visualization', unsafe_allow_html=True)
        viz_style = st.selectbox("Select Visualization Style:", ["ball-stick", "stick", "ball"])
        view_3d = get_structure_viz2(atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
        st.components.v1.html(view_3d._make_html(), width=400, height=400)
        st.markdown("### Structure Information")
        # Cell parameters are only reported for periodic structures with a non-zero cell.
        if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any():
            cell_description = np.round(atoms.cell.cellpar(), 3).tolist()
        else:
            cell_description = "No cell / Non-periodic"
        atoms_info = {
            "Number of Atoms": len(atoms),
            "Chemical Formula": atoms.get_chemical_formula(),
            "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
            "Cell Dimensions": cell_description,
            "Atom Types": ", ".join(sorted(set(atoms.get_chemical_symbols()))),
        }
        for key, value in atoms_info.items():
            st.write(f"**{key}:** {value}")
    # Right column: echo the model/device/task chosen in the sidebar.
    with col2:
        st.markdown('## Calculation Setup', unsafe_allow_html=True)
        st.markdown("### Selected Model")
        st.write(f"**Model Type:** {model_type}")
        st.write(f"**Model:** {selected_model}")
        if model_type == "FairChem" and selected_model == "UMA Small":
            st.write(f"**UMA Task Type:** {selected_task_type}")
        st.write(f"**Device:** {device}")
        st.markdown("### Selected Task")
        st.write(f"**Task:** {task}")
        if "Optimization" in task:
            optimization_summary = (
                ("Max Steps", f"{max_steps}"),
                ("Convergence Threshold", f"{fmax} eV/Å"),
                ("Optimizer", f"{optimizer_type}"),
            )
            for summary_label, summary_text in optimization_summary:
                st.write(f"**{summary_label}:** {summary_text}")
    run_calculation = st.button("Run Calculation", type="primary")
| if run_calculation: | |
| # Delete all the items in Session state | |
| for key in st.session_state.keys(): | |
| del st.session_state[key] | |
| results = {} | |
| #global table_placeholder # Ensure they are accessible | |
| opt_log = [] # Reset log for each run | |
| if "Optimization" in task: | |
| table_placeholder = st.empty() # Recreate placeholder for table | |
| try: | |
| torch.set_default_dtype(torch.float32) | |
| with st.spinner("Running calculation... Please wait."): | |
| calc_atoms = atoms.copy() | |
| if model_type == "MACE": | |
| # st.write("Setting up MACE calculator...") | |
| calc = get_mace_model(model_path, device, 'float32') | |
| elif model_type == "FairChem": # FairChem | |
| # st.write("Setting up FairChem calculator...") | |
| # Workaround for potential dtype issues when switching models | |
| # if device == "cpu": # Ensure torch default dtype matches if needed | |
| # torch.set_default_dtype(torch.float32) | |
| # _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call | |
| calc = get_fairchem_model(selected_model, model_path, device, selected_task_type) | |
| elif model_type == "ORB": | |
| # st.write("Setting up ORB calculator...") | |
| orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype) | |
| calc = ORBCalculator(orbff, device=device) | |
| elif model_type == "MatterSim": | |
| # st.write("Setting up MatterSim calculator...") | |
| # NOTE: Running mattersim on windows requires changing source code file | |
| # https://github.com/microsoft/mattersim/issues/112 | |
| # mattersim/datasets/utils/convertor.py: 117 | |
| # to pbc_ = np.array(structure.pbc, dtype=np.int64) | |
| calc = MatterSimCalculator(load_path=model_path, device=device) | |
| elif model_type == "SEVEN_NET": | |
| # st.write("Setting up SEVENNET calculator...") | |
| if model_path=='7net-mf-ompa': | |
| calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device) | |
| else: | |
| calc = SevenNetCalculator(model=model_path, device=device) | |
| calc_atoms.calc = calc | |
if task == "Energy Calculation":
    # Single-point potential energy only.
    energy = calc_atoms.get_potential_energy()
    results["Energy"] = f"{energy:.6f} eV"
elif task == "Energy + Forces Calculation":
    # Single-point energy plus the largest per-atom force magnitude.
    energy = calc_atoms.get_potential_energy()
    forces = calc_atoms.get_forces()
    if forces.shape[0] > 0:
        per_atom_force = np.sqrt(np.sum(forces**2, axis=1))
        max_force = np.max(per_atom_force)
    else:
        max_force = 0.0
    results.update({
        "Energy": f"{energy:.6f} eV",
        "Maximum Force": f"{max_force:.6f} eV/Å",
    })
| elif task == "Atomization/Cohesive Energy": | |
| st.write("Calculating system energy...") | |
| E_system = calc_atoms.get_potential_energy() | |
| num_atoms = len(calc_atoms) | |
| if num_atoms == 0: | |
| st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.") | |
| results["Error"] = "System has no atoms." | |
| else: | |
| atomic_numbers = calc_atoms.get_atomic_numbers() | |
| E_isolated_atoms_total = 0.0 | |
| calculation_possible = True | |
| if model_type == "FairChem": | |
| st.write("Fetching FairChem reference energies for isolated atoms...") | |
| ref_key_suffix = "_elem_refs" | |
| chosen_ref_list_name = None | |
| if selected_model == "UMA Small": | |
| if selected_task_type: | |
| chosen_ref_list_name = selected_task_type + ref_key_suffix | |
| elif "ESEN" in selected_model: | |
| chosen_ref_list_name = "omol" + ref_key_suffix | |
| if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES: | |
| ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name] | |
| missing_Z_refs = [] | |
| for Z_val in atomic_numbers: | |
| if Z_val > 0 and Z_val < len(ref_energies): | |
| E_isolated_atoms_total += ref_energies[Z_val] | |
| else: | |
| if Z_val not in missing_Z_refs: missing_Z_refs.append(Z_val) | |
| if missing_Z_refs: | |
| st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} " | |
| f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). " | |
| "These atoms are treated as having 0 reference energy.") | |
| else: | |
| st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' " | |
| f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.") | |
| results["Error"] = "Missing FairChem reference energies." | |
| calculation_possible = False | |
| else:# == "MACE": | |
| st.write("Calculating isolated atom energies with MACE...") | |
| unique_atomic_numbers = sorted(list(set(atomic_numbers))) | |
| atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers} | |
| progress_text = "Calculating isolated atom energies: 0% complete" | |
| mace_progress_bar = st.progress(0, text=progress_text) | |
| for i, Z_unique in enumerate(unique_atomic_numbers): | |
| isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False) | |
| if not hasattr(isolated_atom, 'info'): isolated_atom.info = {} | |
| isolated_atom.info["charge"] = 0 | |
| isolated_atom.info["spin"] = 0 | |
| isolated_atom.calc = calc # Use the same MACE calculator | |
| E_isolated_atom_type = isolated_atom.get_potential_energy() | |
| E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique] | |
| progress_val = (i + 1) / len(unique_atomic_numbers) | |
| mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete") | |
| mace_progress_bar.empty() | |
| if calculation_possible: | |
| is_periodic = any(calc_atoms.pbc) | |
| if is_periodic: | |
| cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms | |
| results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom" | |
| else: | |
| atomization_E = E_isolated_atoms_total - E_system | |
| results["Atomization Energy"] = f"{atomization_E:.6f} eV" | |
| results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV" | |
| results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV" | |
| elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt | |
| is_periodic = any(calc_atoms.pbc) | |
| opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms | |
| # Create temporary trajectory file | |
| traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name | |
| if optimizer_type == "BFGS": | |
| opt = BFGS(opt_atoms_obj, trajectory=traj_filename) | |
| elif optimizer_type == "LBFGS": | |
| opt = LBFGS(opt_atoms_obj, trajectory=traj_filename) | |
| else: # FIRE | |
| opt = FIRE(opt_atoms_obj, trajectory=traj_filename) | |
| # opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly | |
| opt.attach(lambda: streamlit_log(opt), interval=1) | |
| st.write(f"Running {task.lower()}...") | |
| opt.run(fmax=fmax, steps=max_steps) | |
| energy = calc_atoms.get_potential_energy() | |
| forces = calc_atoms.get_forces() | |
| max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0 | |
| results["Final Energy"] = f"{energy:.6f} eV" | |
| results["Final Maximum Force"] = f"{max_force:.6f} eV/Å" | |
| results["Steps Taken"] = opt.get_number_of_steps() | |
| results["Converged"] = "Yes" if opt.converged() else "No" | |
| if task == "Cell + Geometry Optimization": | |
| results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist() | |
st.success("Calculation completed successfully!")
st.markdown("### Results")
for key, value in results.items():
    st.write(f"**{key}:** {value}")
if "Optimization" in task and "Final Energy" in results:  # Only after a completed optimization
    st.markdown("### Optimized Structure")
    # Render the relaxed Atoms object itself: for cell optimizations
    # opt_atoms_obj is a FrechetCellFilter wrapper, not an Atoms object.
    opt_view = get_structure_viz2(calc_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    # Serialize in memory instead of writing by name into a still-open
    # NamedTemporaryFile (which fails on Windows) and unlinking it afterwards.
    # extxyz keeps cell/PBC information for periodic systems.
    xyz_buffer = io.StringIO()
    write(xyz_buffer, calc_atoms, format="extxyz" if is_periodic else "xyz")
    xyz_content_opt = xyz_buffer.getvalue()

    # A download click reruns the script; as a Streamlit fragment only this
    # button reruns, so the results above are not wiped. Falls back to a plain
    # function on Streamlit versions without fragments.
    _as_fragment = (getattr(st, "fragment", None)
                    or getattr(st, "experimental_fragment", None)
                    or (lambda func: func))

    @_as_fragment
    def show_optimized_structure_download_button():
        """Render the download button for the optimized structure as XYZ text."""
        st.download_button(
            label="Download Optimized Structure (XYZ)",
            data=xyz_content_opt,
            file_name="optimized_structure.xyz",
            mime="chemical/x-xyz"
        )

    show_optimized_structure_download_button()
# st.button returns True only on the run right after the click, so a full
# rerun triggered by the navigation buttons below would skip this code
# (run_calculation is False then). As a Streamlit fragment only this viewer
# reruns; on Streamlit versions without fragments it is a plain function.
_as_fragment = (getattr(st, "fragment", None)
                or getattr(st, "experimental_fragment", None)
                or (lambda func: func))


@_as_fragment
def show_trajectory_and_controls():
    """Show a frame-by-frame viewer and XYZ download for the optimization trajectory.

    Frames are read once from ``traj_filename`` and cached in
    ``st.session_state.traj_frames``; ``st.session_state.traj_index`` holds the
    frame currently displayed.
    """
    from ase.io import read

    if "traj_frames" not in st.session_state:
        if not os.path.exists(traj_filename):
            st.warning("Trajectory file not found.")
            return
        try:
            st.session_state.traj_frames = read(traj_filename, index=":")
            st.session_state.traj_index = 0
        except Exception as e:
            st.error(f"Error reading trajectory: {e}")
            return
        finally:
            # The frames now live in session state, so the temporary file is no
            # longer needed. Best-effort only: on some platforms the optimizer
            # may still hold the file open, in which case it is left behind.
            try:
                os.unlink(traj_filename)
            except OSError:
                pass

    trajectory = st.session_state.traj_frames
    index = st.session_state.traj_index
    st.markdown("### Optimization Trajectory")
    st.write(f"Captured {len(trajectory)} optimization steps")

    # Navigation buttons (bounds-checked against the frame count).
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        if st.button("⏮ First"):
            st.session_state.traj_index = 0
    with col2:
        if st.button("◀ Previous") and index > 0:
            st.session_state.traj_index -= 1
    with col3:
        if st.button("Next ▶") and index < len(trajectory) - 1:
            st.session_state.traj_index += 1
    with col4:
        if st.button("Last ⏭"):
            st.session_state.traj_index = len(trajectory) - 1

    # Show current frame
    current_atoms = trajectory[st.session_state.traj_index]
    st.write(f"Frame {st.session_state.traj_index + 1}/{len(trajectory)}")

    def atoms_to_xyz_string(atoms, step_idx=None):
        """Format one frame as an XYZ record whose comment line carries its energy."""
        frame_energy = atoms.get_potential_energy()
        if step_idx is not None:
            comment = f"Step {step_idx}, Energy = {frame_energy:.6f} eV"
        else:
            comment = f"Energy = {frame_energy:.6f} eV"
        lines = [f"{len(atoms)}", comment]
        lines.extend(
            f"{symbol} {pos[0]:.6f} {pos[1]:.6f} {pos[2]:.6f}"
            for symbol, pos in zip(atoms.get_chemical_symbols(), atoms.get_positions())
        )
        return "\n".join(lines) + "\n"

    traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
    st.components.v1.html(traj_view._make_html(), width=400, height=400)

    # Download button for the entire trajectory (multi-frame XYZ); join once
    # instead of repeated string concatenation.
    trajectory_xyz = "".join(atoms_to_xyz_string(frame, i) for i, frame in enumerate(trajectory))
    st.download_button(
        label="Download Optimization Trajectory (XYZ)",
        data=trajectory_xyz,
        file_name="optimization_trajectory.xyz",
        mime="chemical/x-xyz"
    )


show_trajectory_and_controls()
| elif task == "Vibrational Mode Analysis": | |
| st.write("Running vibrational mode analysis using finite differences...") | |
| natoms = len(calc_atoms) | |
| is_linear = False # Set manually or auto-detect | |
| nmodes_expected = 3 * natoms - (5 if is_linear else 6) | |
| # Create temporary directory to store .vib files | |
| with tempfile.TemporaryDirectory() as tmpdir: | |
| vib = Vibrations(calc_atoms, name=os.path.join(tmpdir, 'vib')) | |
| with st.spinner("Calculating vibrational modes... This may take a few minutes."): | |
| vib.run() | |
| freqs = vib.get_frequencies() | |
| # Convert frequencies to cm⁻¹ | |
| freqs_cm = freqs #/ cm | |
| # Classify frequencies | |
| mode_data = [] | |
| for i, freq in enumerate(freqs_cm): | |
| if freq < 0: | |
| label = "Imaginary" | |
| elif abs(freq) < 500: | |
| label = "Low" | |
| else: | |
| label = "Physical" | |
| mode_data.append({ | |
| "Mode": i + 1, | |
| "Frequency (cm⁻¹)": round(freq, 2), | |
| "Type": label | |
| }) | |
| df_modes = pd.DataFrame(mode_data) | |
| # Display summary and mode count | |
| st.success("Vibrational analysis completed.") | |
| st.write(f"Number of atoms: {natoms}") | |
| st.write(f"Expected vibrational modes: {nmodes_expected}") | |
| st.write(f"Found {len(freqs_cm)} modes (including translational/rotational modes).") | |
| # Show table of modes | |
| st.write("### Vibrational Mode Summary") | |
| st.dataframe(df_modes, use_container_width=True) | |
| # Store in results dictionary | |
| results["Vibrational Modes"] = df_modes.to_dict(orient="records") | |
| # Histogram plot of vibrational frequencies | |
| st.write("### Frequency Distribution Histogram") | |
| fig, ax = plt.subplots() | |
| ax.hist(freqs_cm, bins=30, color='skyblue', edgecolor='black') | |
| ax.set_xlabel("Frequency (cm⁻¹)") | |
| ax.set_ylabel("Number of Modes") | |
| ax.set_title("Distribution of Vibrational Frequencies") | |
| st.pyplot(fig) | |
| # CSV download | |
| csv_buffer = io.StringIO() | |
| df_modes.to_csv(csv_buffer, index=False) | |
| st.download_button( | |
| label="Download Vibrational Frequencies (CSV)", | |
| data=csv_buffer.getvalue(), | |
| file_name="vibrational_modes.csv", | |
| mime="text/csv" | |
| ) | |
| except Exception as e: | |
| st.error(f"🔴 Calculation error: {str(e)}") | |
| st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).") | |
| import traceback | |
| st.error(f"Traceback: {traceback.format_exc()}") | |
| else: | |
| st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.") | |
| st.markdown("---") | |
| with st.expander('ℹ️ About This App & Foundational MLIPs'): | |
| st.write(""" | |
| **Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).** | |
| This application allows you to perform atomistic simulations using pre-trained foundational MLIPs | |
| from the MACE and FairChem (by Meta AI) libraries. | |
| **Features:** | |
| - Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples. | |
| - Select from various MACE and FairChem models. | |
| - Calculate energies, forces, and perform geometry/cell optimizations. | |
| - **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems). | |
| - Visualize atomic structures in 3D and download results. | |
| **Quick Start:** | |
| 1. **Input**: Choose an input method in the sidebar (e.g., "Select Example"). | |
| 2. **Model**: Pick a model type (MACE/FairChem) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials). | |
| 3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization"). | |
| 4. **Run**: Click "Run Calculation" and view the results. | |
| **Atomization/Cohesive Energy Notes:** | |
| - **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules). | |
| - **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems. | |
| - For **MACE models**, isolated atom energies are computed on-the-fly. | |
| - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references. | |
| """) | |
| st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️") | |
| st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan Group](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)") |