Spaces:
Running
Running
| import streamlit as st | |
| import os | |
| import tempfile | |
| import torch | |
| import numpy as np | |
| from ase import Atoms | |
| from ase.io import read, write | |
| from ase.optimize import BFGS, LBFGS, FIRE | |
| from ase.constraints import FixAtoms | |
| from ase.filters import FrechetCellFilter | |
| from ase.visualize import view | |
| import py3Dmol | |
| from mace.calculators import mace_mp | |
| from fairchem.core import pretrained_mlip, FAIRChemCalculator | |
| import pandas as pd | |
| import yaml # Added for FairChem reference energies | |
| from huggingface_hub import login | |
# Authenticate with the Hugging Face Hub, presumably so that gated FairChem
# checkpoints can be downloaded. The token is read from the HF_TOKEN environment
# variable (configure it as a Space/deployment secret); never hard-code a token
# in this file. Failures are only printed so the app still starts without it.
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:
    print(f"hf login error: {e}")

# NOTE(review): Streamlit reads its configuration when the server starts, i.e.
# before this script runs, and the documented env var for this option is
# STREAMLIT_SERVER_FILE_WATCHER_TYPE. This assignment therefore most likely has
# no effect; prefer `server.fileWatcherType = "none"` in .streamlit/config.toml.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
# Per-element reference energies for the FairChem models, embedded as YAML.
# There is one list per dataset/task head: oc20, odac, omat, omol and omc.
# The "Atomization/Cohesive Energy" task sums these values as isolated-atom
# energies (reported in eV).
#
# Each list has 100 entries indexed by atomic number Z. Index 0 is a 0.0
# placeholder; only 0 < Z < len(list) is looked up, and any other Z is treated
# as missing.
#
# List selection:
# - UMA picks "<task>_elem_refs" from its task type.
# - The eSEN models always use omol_elem_refs.
#
# Data notes:
# - Trailing 0.0 entries for heavy elements look like "no reference
#   available" placeholders — TODO confirm.
#
# NOTE(review):
# - The values presumably mirror FairChem's published element references;
#   confirm the source/version before editing.
# - omol_elem_refs holds values of a very different magnitude (hundreds to
#   tens of thousands of eV) from the other lists.
# - omc_elem_refs[97] is exactly equal to omc_elem_refs[4]; verify that this is
#   not a copy/paste slip.
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
# Parse the embedded reference-energy YAML once, at import time.
# Streamlit output is avoided here because st.set_page_config() runs later and
# must be the first Streamlit call. On a parse error the message goes to stdout
# and an empty mapping is kept; the atomization task then reports the
# reference list as missing.
ELEMENT_REF_ENERGIES = {}  # Fallback
try:
    ELEMENT_REF_ENERGIES = yaml.safe_load(ELEMENT_REF_ENERGIES_YAML)
except yaml.YAMLError as yaml_err:
    print(f"Error parsing YAML reference energies: {yaml_err}")
# Deployment guard-rails.
# - is_streamlit_cloud is True only when STREAMLIT_RUNTIME_ENV == "cloud".
#   NOTE(review): it is not referenced anywhere else in the code shown.
# - check_atom_limit() applies MAX_ATOMS_CLOUD_UMA to UMA / "ESEN MD" models
#   and MAX_ATOMS_CLOUD to every other model.
is_streamlit_cloud = os.getenv('STREAMLIT_RUNTIME_ENV') == 'cloud'
MAX_ATOMS_CLOUD = 500  # general atom cap per structure
MAX_ATOMS_CLOUD_UMA = MAX_ATOMS_CLOUD  # cap for UMA / "ESEN MD" (currently the same value)
# Page configuration: st.set_page_config must be the first Streamlit call of a run.
st.set_page_config(page_title="Molecular Structure Analysis", page_icon="🧪", layout="wide")

# Title and description shown at the top of the page.
_APP_TITLE = '## MLIP Playground'
_APP_TAGLINE = (
    '#### Run, test and compare >17 state-of-the-art universal machine learning '
    'interatomic potentials (MLIPs) for atomistic simulations of molecules and materials'
)
_APP_INTRO = (
    'Upload molecular structure files or select from predefined examples, then compute '
    'energies and forces using foundation models such as those from MACE or FairChem (Meta).'
)
st.markdown(_APP_TITLE, unsafe_allow_html=True)
st.write(_APP_TAGLINE)
st.markdown(_APP_INTRO, unsafe_allow_html=True)
# Example structures offered in the sidebar, as file names under SAMPLE_DIR.
# The directory is created if missing. This code never creates the .xyz files
# themselves; they presumably ship alongside the app.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)

# Display name -> file name. Insertion order is the selectbox order.
SAMPLE_STRUCTURES = dict(
    Water="H2O.xyz",
    Methane="CH4.xyz",
    Benzene="C6H6.xyz",
    Ethane="C2H6.xyz",
    Caffeine="caffeine.xyz",
    Ibuprofen="ibuprofen.xyz",
)
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """Build a py3Dmol viewer for an ASE ``Atoms`` object.

    The structure is serialised to an in-memory XYZ string (chemical symbol
    plus Cartesian coordinates with 6 decimals) and rendered in the requested
    style. For structures with any periodic direction and a non-zero cell, the
    12 edges of the unit cell are drawn as thin black cylinders.

    Args:
        atoms_obj: ASE ``Atoms`` to display.
        style: 'ball_stick', 'stick' or 'ball' (case-insensitive); any other
            value falls back to thin sticks.
        show_unit_cell: draw the cell edges when at least one PBC flag is set.
        width: viewer width in pixels.
        height: viewer height in pixels.

    Returns:
        The configured ``py3Dmol.view`` instance.
    """
    xyz_lines = [str(len(atoms_obj)), "Structure"]
    xyz_lines += [
        f"{symbol} {pos[0]:.6f} {pos[1]:.6f} {pos[2]:.6f}"
        for symbol, pos in zip(atoms_obj.get_chemical_symbols(), atoms_obj.positions)
    ]
    xyz_str = "\n".join(xyz_lines) + "\n"

    # Style name -> py3Dmol style spec; unknown names use thin sticks.
    style_specs = {
        'ball_stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(xyz_str, "xyz")
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        # Skip drawing for an all-zero cell (e.g. molecules with pbc set).
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            o = np.array([0.0, 0.0, 0.0])
            edges = [
                (o, a), (o, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (o, c), (a, a + c), (b, b + c),
                (a + b, a + b + c),
            ]
            for start, end in edges:
                viewer.addCylinder({
                    'start': dict(zip('xyz', start)),
                    'end': dict(zip('xyz', end)),
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7,
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
# Shared optimisation history and its on-page table.
# The "Run Calculation" handler rebinds both before each optimisation, and
# streamlit_log() reads whichever objects are current.
opt_log = []
table_placeholder = st.empty()


def streamlit_log(opt):
    """ASE optimizer observer: record the current step and redraw the log table.

    Appends the step index, the potential energy (eV) and the largest per-row
    force norm (eV/Å) of ``opt.atoms`` to ``opt_log``, then re-renders the
    whole log as a DataFrame in ``table_placeholder``. Any exception is
    reported as a Streamlit warning rather than aborting the optimisation.
    """
    try:
        target = opt.atoms
        energy = target.get_potential_energy()
        forces = target.get_forces()
        fmax_step = np.linalg.norm(forces, axis=1).max() if forces.shape[0] > 0 else 0.0
        opt_log.append({
            "Step": opt.nsteps,
            "Energy (eV)": round(energy, 6),
            "Fmax (eV/Å)": round(fmax_step, 6),
        })
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as err:
        st.warning(f"Error in optimization logger: {err}")
def check_atom_limit(atoms_obj, selected_model):
    """Return whether ``atoms_obj`` is within the atom cap for ``selected_model``.

    UMA and "ESEN MD" models are checked against MAX_ATOMS_CLOUD_UMA, all other
    models against MAX_ATOMS_CLOUD. ``None`` (no structure loaded) always
    passes. When the cap is exceeded, a Streamlit error is shown and ``False``
    is returned.

    NOTE(review): the cap is enforced regardless of ``is_streamlit_cloud``,
    even though the message refers to Streamlit Cloud.
    """
    if atoms_obj is None:
        return True
    num_atoms = len(atoms_obj)
    uses_uma_cap = any(tag in selected_model for tag in ('UMA', 'ESEN MD'))
    limit = MAX_ATOMS_CLOUD_UMA if uses_uma_cap else MAX_ATOMS_CLOUD
    if num_atoms <= limit:
        return True
    st.error(f"⚠️ Error: Your structure contains {num_atoms} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# Base URLs of the two ACEsuit GitHub release feeds that host the MACE checkpoints.
_MACE_MP_RELEASES = "https://github.com/ACEsuit/mace-mp/releases/download"
_MACE_FOUNDATIONS_RELEASES = "https://github.com/ACEsuit/mace-foundations/releases/download"

# Display name -> checkpoint URL, passed to mace_mp(model=...) via get_mace_model().
# Insertion order is the order shown in the sidebar selectbox.
MACE_MODELS = {
    "MACE MPA Medium": f"{_MACE_MP_RELEASES}/mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": f"{_MACE_MP_RELEASES}/mace_omat_0/mace-omat-0-medium.model",
    "MACE MATPES r2SCAN Medium": f"{_MACE_FOUNDATIONS_RELEASES}/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": f"{_MACE_FOUNDATIONS_RELEASES}/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": f"{_MACE_MP_RELEASES}/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": f"{_MACE_MP_RELEASES}/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": f"{_MACE_MP_RELEASES}/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b2/mace-small-density-agnesi-stress.model",
    "MACE MP 0b2 Medium": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": f"{_MACE_FOUNDATIONS_RELEASES}/mace_mp_0b3/mace-mp-0b3-medium.model",
}
# Display name -> FairChem pretrained checkpoint identifier, passed to
# pretrained_mlip.get_predict_unit() by get_fairchem_model().
# Only "UMA Small" additionally receives a task name.
_FAIRCHEM_CHECKPOINTS = (
    ("UMA Small", "uma-sm"),
    ("ESEN MD Direct All OMOL", "esen-md-direct-all-omol"),
    ("ESEN SM Conserving All OMOL", "esen-sm-conserving-all-omol"),
    ("ESEN SM Direct All OMOL", "esen-sm-direct-all-omol"),
)
FAIRCHEM_MODELS = dict(_FAIRCHEM_CHECKPOINTS)
def get_mace_model(model_path, device, selected_default_dtype):
    """Create a MACE ASE calculator via ``mace.calculators.mace_mp``.

    Args:
        model_path: checkpoint URL or path (values of MACE_MODELS).
        device: torch device string, "cpu" or "cuda".
        selected_default_dtype: 'float32' or 'float64', as offered in the sidebar.
    """
    calculator = mace_mp(
        model=model_path,
        device=device,
        default_dtype=selected_default_dtype,
    )
    return calculator
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc):
    """Build a ``FAIRChemCalculator`` for a pretrained FairChem checkpoint.

    Args:
        selected_model_name: sidebar display name. Only "UMA Small" is given a
            ``task_name``, presumably because it is the only multi-task
            checkpoint offered here.
        model_path_or_name: checkpoint identifier (values of FAIRCHEM_MODELS).
        device: torch device string, "cpu" or "cuda".
        selected_task_type_fc: UMA task head ("omol", "omat", ...); ignored
            for non-UMA models.
    """
    predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    if selected_model_name != "UMA Small":
        return FAIRChemCalculator(predictor)
    return FAIRChemCalculator(predictor, task_name=selected_task_type_fc)
def _read_atoms_from_bytes(payload, suffix):
    """Parse structure-file bytes with ASE by way of a temporary file.

    ASE guesses the format from the file name, so the temporary file is given
    ``suffix``. The file is closed before ASE opens it (which keeps this
    portable to Windows) and is always removed, even when writing or parsing
    fails.
    """
    fd, tmp_filepath = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            tmp_file.write(payload)
        return read(tmp_filepath)
    finally:
        if os.path.exists(tmp_filepath):
            os.unlink(tmp_filepath)


# Sidebar: choose how the input structure is provided. `atoms` stays None
# until a structure has been parsed successfully.
st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
atoms = None

if input_method == "Upload File":
    uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
    if uploaded_file:
        try:
            atoms = _read_atoms_from_bytes(uploaded_file.getvalue(), os.path.splitext(uploaded_file.name)[1])
            st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading file: {str(e)}")
elif input_method == "Select Example":
    example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
    if example_name:
        file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
        try:
            atoms = read(file_path)
            st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error loading example: {str(e)}")
elif input_method == "Paste Content":
    file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
    content = st.sidebar.text_area("Paste file content here:", height=200)
    if content:
        # The suffix only steers ASE's format detection for the temporary file.
        suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
        try:
            atoms = _read_atoms_from_bytes(content.encode(), suffix_map.get(file_format, ".xyz"))
            st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
        except Exception as e:
            st.sidebar.error(f"Error parsing content: {str(e)}")

if atoms is not None:
    # Guarantee that charge/spin keys exist; values read from the file are kept.
    # NOTE(review): "spin" is ambiguous here. The model-selection step treats
    # it as S and converts it to the multiplicity 2S+1 for OMol models.
    atoms.info.setdefault("charge", 0)
    atoms.info.setdefault("spin", 0)
# Sidebar: model family, checkpoint and model-specific options.
st.sidebar.markdown("## Model Selection")
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
selected_task_type = None  # Only set for FairChem UMA (the task head)

if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32', 'float64'])
elif model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if selected_model == "UMA Small":
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])

    # OMol-style models take total charge and spin multiplicity as inputs.
    # These are UMA with the "omol" task and every eSEN "... OMOL" checkpoint;
    # the eSEN ones were previously fed the raw default spin of 0, which is
    # not a valid multiplicity.
    uses_omol_inputs = selected_task_type == "omol" or selected_model.endswith("OMOL")
    if uses_omol_inputs and atoms is not None:
        # Cast to int and clamp so the defaults always satisfy the widget
        # bounds and types. A float charge read from a file would otherwise
        # raise a Streamlit mixed-numeric-types error.
        default_charge = min(max(int(atoms.info.get("charge", 0) or 0), -10), 10)
        spin_s = atoms.info.get("spin", 0) or 0  # assumed to be S in atoms.info
        default_multiplicity = min(max(int(spin_s * 2 + 1), 1), 11)

        charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=default_charge)
        # step=1: odd-electron systems need even multiplicities (doublet,
        # quartet, ...). A step of 2 starting at 1 could only produce odd values.
        spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=1, value=default_multiplicity)
        atoms.info["charge"] = charge
        atoms.info["spin"] = spin_multiplicity  # FairChem expects multiplicity
# Halt this script run for structures above the atom cap.
# check_atom_limit() has already displayed the error at that point.
if atoms is not None and not check_atom_limit(atoms, selected_model):
    st.stop()
# Computation device.
# - The GPU option is pre-selected when CUDA is available.
# - Everything falls back to "cpu" when CUDA is not usable.
_cuda_ok = torch.cuda.is_available()
device = st.sidebar.radio("Computation Device:", ["CPU", "CUDA (GPU)"], index=1 if _cuda_ok else 0)
device = "cuda" if (device == "CUDA (GPU)" and _cuda_ok) else "cpu"
if device == "cpu":
    st.sidebar.info("GPU is available but CPU was selected." if _cuda_ok else "No GPU detected. Using CPU.")
# Sidebar: which calculation to run.
st.sidebar.markdown("## Task Selection")
_TASK_CHOICES = [
    "Energy Calculation",
    "Energy + Forces Calculation",
    "Atomization/Cohesive Energy",
    "Geometry Optimization",
    "Cell + Geometry Optimization",
]
task = st.sidebar.selectbox("Select Calculation Task:", _TASK_CHOICES)

# max_steps, fmax and optimizer_type are only defined for the optimisation tasks.
if "Optimization" in task:
    st.sidebar.markdown("### Optimization Parameters")
    max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
    fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
    optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)  # LBFGS by default
| if atoms is not None: | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.markdown('### Structure Visualization', unsafe_allow_html=True) | |
| view_3d = get_structure_viz2(atoms, style='stick', show_unit_cell=True, width=400, height=400) | |
| st.components.v1.html(view_3d._make_html(), width=400, height=400) | |
| st.markdown("### Structure Information") | |
| atoms_info = { | |
| "Number of Atoms": len(atoms), | |
| "Chemical Formula": atoms.get_chemical_formula(), | |
| "Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(), | |
| "Cell Dimensions": np.round(atoms.cell.cellpar(),3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic", | |
| "Atom Types": ", ".join(sorted(list(set(atoms.get_chemical_symbols())))) | |
| } | |
| for key, value in atoms_info.items(): | |
| st.write(f"**{key}:** {value}") | |
| with col2: | |
| st.markdown('## Calculation Setup', unsafe_allow_html=True) | |
| st.markdown("### Selected Model") | |
| st.write(f"**Model Type:** {model_type}") | |
| st.write(f"**Model:** {selected_model}") | |
| if model_type == "FairChem" and selected_model == "UMA Small": | |
| st.write(f"**UMA Task Type:** {selected_task_type}") | |
| st.write(f"**Device:** {device}") | |
| st.markdown("### Selected Task") | |
| st.write(f"**Task:** {task}") | |
| if "Optimization" in task: | |
| st.write(f"**Max Steps:** {max_steps}") | |
| st.write(f"**Convergence Threshold:** {fmax} eV/Å") | |
| st.write(f"**Optimizer:** {optimizer_type}") | |
| run_calculation = st.button("Run Calculation", type="primary") | |
| if run_calculation: | |
| results = {} | |
| #global table_placeholder # Ensure they are accessible | |
| opt_log = [] # Reset log for each run | |
| if "Optimization" in task: | |
| table_placeholder = st.empty() # Recreate placeholder for table | |
| try: | |
| with st.spinner("Running calculation... Please wait."): | |
| calc_atoms = atoms.copy() | |
| if model_type == "MACE": | |
| # st.write("Setting up MACE calculator...") | |
| calc = get_mace_model(model_path, device, selected_default_dtype) | |
| else: # FairChem | |
| # st.write("Setting up FairChem calculator...") | |
| # Workaround for potential dtype issues when switching models | |
| if device == "cpu": # Ensure torch default dtype matches if needed | |
| torch.set_default_dtype(torch.float32) | |
| _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call | |
| calc = get_fairchem_model(selected_model, model_path, device, selected_task_type) | |
| calc_atoms.calc = calc | |
| if task == "Energy Calculation": | |
| energy = calc_atoms.get_potential_energy() | |
| results["Energy"] = f"{energy:.6f} eV" | |
| elif task == "Energy + Forces Calculation": | |
| energy = calc_atoms.get_potential_energy() | |
| forces = calc_atoms.get_forces() | |
| max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0 | |
| results["Energy"] = f"{energy:.6f} eV" | |
| results["Maximum Force"] = f"{max_force:.6f} eV/Å" | |
| elif task == "Atomization/Cohesive Energy": | |
| st.write("Calculating system energy...") | |
| E_system = calc_atoms.get_potential_energy() | |
| num_atoms = len(calc_atoms) | |
| if num_atoms == 0: | |
| st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.") | |
| results["Error"] = "System has no atoms." | |
| else: | |
| atomic_numbers = calc_atoms.get_atomic_numbers() | |
| E_isolated_atoms_total = 0.0 | |
| calculation_possible = True | |
| if model_type == "FairChem": | |
| st.write("Fetching FairChem reference energies for isolated atoms...") | |
| ref_key_suffix = "_elem_refs" | |
| chosen_ref_list_name = None | |
| if selected_model == "UMA Small": | |
| if selected_task_type: | |
| chosen_ref_list_name = selected_task_type + ref_key_suffix | |
| elif "ESEN" in selected_model: | |
| chosen_ref_list_name = "omol" + ref_key_suffix | |
| if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES: | |
| ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name] | |
| missing_Z_refs = [] | |
| for Z_val in atomic_numbers: | |
| if Z_val > 0 and Z_val < len(ref_energies): | |
| E_isolated_atoms_total += ref_energies[Z_val] | |
| else: | |
| if Z_val not in missing_Z_refs: missing_Z_refs.append(Z_val) | |
| if missing_Z_refs: | |
| st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} " | |
| f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). " | |
| "These atoms are treated as having 0 reference energy.") | |
| else: | |
| st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' " | |
| f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.") | |
| results["Error"] = "Missing FairChem reference energies." | |
| calculation_possible = False | |
| elif model_type == "MACE": | |
| st.write("Calculating isolated atom energies with MACE...") | |
| unique_atomic_numbers = sorted(list(set(atomic_numbers))) | |
| atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers} | |
| progress_text = "Calculating isolated atom energies: 0% complete" | |
| mace_progress_bar = st.progress(0, text=progress_text) | |
| for i, Z_unique in enumerate(unique_atomic_numbers): | |
| isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False) | |
| if not hasattr(isolated_atom, 'info'): isolated_atom.info = {} | |
| isolated_atom.info["charge"] = 0 | |
| isolated_atom.info["spin"] = 0 | |
| isolated_atom.calc = calc # Use the same MACE calculator | |
| E_isolated_atom_type = isolated_atom.get_potential_energy() | |
| E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique] | |
| progress_val = (i + 1) / len(unique_atomic_numbers) | |
| mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete") | |
| mace_progress_bar.empty() | |
| if calculation_possible: | |
| is_periodic = any(calc_atoms.pbc) | |
| if is_periodic: | |
| cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms | |
| results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom" | |
| else: | |
| atomization_E = E_isolated_atoms_total - E_system | |
| results["Atomization Energy"] = f"{atomization_E:.6f} eV" | |
| results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV" | |
| results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV" | |
| elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt | |
| opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms | |
| if optimizer_type == "BFGS": | |
| opt = BFGS(opt_atoms_obj) | |
| elif optimizer_type == "LBFGS": | |
| opt = LBFGS(opt_atoms_obj) | |
| else: # FIRE | |
| opt = FIRE(opt_atoms_obj) | |
| # opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly | |
| opt.attach(lambda: streamlit_log(opt), interval=1) | |
| st.write(f"Running {task.lower()}...") | |
| opt.run(fmax=fmax, steps=max_steps) | |
| energy = calc_atoms.get_potential_energy() | |
| forces = calc_atoms.get_forces() | |
| max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0 | |
| results["Final Energy"] = f"{energy:.6f} eV" | |
| results["Final Maximum Force"] = f"{max_force:.6f} eV/Å" | |
| results["Steps Taken"] = opt.get_number_of_steps() | |
| results["Converged"] = "Yes" if opt.converged() else "No" | |
| if task == "Cell + Geometry Optimization": | |
| results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist() | |
st.success("Calculation completed successfully!")
st.markdown("### Results")
for key, value in results.items():
    st.write(f"**{key}:** {value}")
# "Final Energy" is set by the optimization branch; requiring it avoids showing
# an "optimized" structure when that branch did not complete.
if "Optimization" in task and "Final Energy" in results:
    import io  # local import: only needed for the in-memory XYZ export below

    st.markdown("### Optimized Structure")
    opt_view = get_structure_viz2(calc_atoms, style='stick', show_unit_cell=True, width=400, height=400)
    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    # Serialize in memory rather than via NamedTemporaryFile(delete=False) +
    # reopen + os.unlink: that round-trip left a stray temp file on disk
    # whenever writing or reading raised before the unlink ran.
    xyz_buffer = io.StringIO()
    write(xyz_buffer, calc_atoms, format="xyz")
    st.download_button(
        label="Download Optimized Structure (XYZ)",
        data=xyz_buffer.getvalue(),
        file_name="optimized_structure.xyz",
        mime="chemical/x-xyz",
    )
| except Exception as e: | |
st.error(f"🔴 Calculation error: {e}")
st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
# st.exception renders the exception type, message and traceback in a
# preformatted block. Passing traceback text through st.error rendered it as
# Markdown, which collapses its line breaks and mangles names such as __init__.
st.exception(e)
| else: | |
# Prompt shown in this branch, asking the user to pick or upload a structure.
welcome_message = "👋 Welcome! Please select or upload a structure using the sidebar options to begin."
st.info(welcome_message)
st.markdown("---")
with st.expander('ℹ️ About This App & Foundational MLIPs'):
    # Raw string on purpose: in a plain literal the LaTeX command "\text" begins
    # with the escape "\t", so it became a TAB followed by "ext", and "\sum" is
    # an invalid escape (SyntaxWarning on Python 3.12+). Blank lines separate the
    # Markdown blocks so bold headings are not swallowed into the preceding list
    # item as lazy continuation lines.
    st.write(r"""
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**

This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
from the MACE and FairChem (by Meta AI) libraries.

**Features:**
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
- Select from various MACE and FairChem models.
- Calculate energies, forces, and perform geometry/cell optimizations.
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
- Visualize atomic structures in 3D and download results.

**Quick Start:**
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
2. **Model**: Pick a model type (MACE/FairChem) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
4. **Run**: Click "Run Calculation" and view the results.

**Atomization/Cohesive Energy Notes:**
- **Atomization Energy** ($E_{\text{atomization}} = \sum E_{\text{isolated atoms}} - E_{\text{molecule}}$) is typically for non-periodic systems (molecules).
- **Cohesive Energy** ($E_{\text{cohesive}} = (\sum E_{\text{isolated atoms}} - E_{\text{bulk system}}) / N_{\text{atoms}}$) is for periodic systems.
- For **MACE models**, isolated atom energies are computed on-the-fly.
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
""")
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem and ❤️")
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) ([Fundamental AI Research (FAIR) team, Meta AI](https://ai.meta.com/research/fair/) and [Ananth Govind Rajan Group, IISc Bangalore](https://www.agrgroup.org/))")