# PyFock-GUI / src/streamlit_app.py
# Author: ManasSharma07
# Commit 5c86435 (verified): "add orbmol model and fix issues with orb models"
# File size: 92.9 kB
import streamlit as st
import os
import io
import tempfile
import torch
# FOR CPU only mode
#torch._dynamo.config.suppress_errors = True
# Or disable compilation entirely
# torch.backends.cudnn.enabled = False
import numpy as np
from ase import Atoms
from ase.io import read, write
from ase.optimize import BFGS, LBFGS, FIRE
from ase.optimize.basin import BasinHopping
from ase.optimize.minimahopping import MinimaHopping
from ase.units import kB
from ase.constraints import FixAtoms
from ase.filters import FrechetCellFilter
from ase.visualize import view
import py3Dmol
from mace.calculators import mace_mp
from fairchem.core import pretrained_mlip, FAIRChemCalculator
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
from sevenn.calculator import SevenNetCalculator
import pandas as pd
import yaml # Added for FairChem reference energies
import subprocess
import sys
import pkg_resources
from ase.vibrations import Vibrations
from mp_api.client import MPRester
import pubchempy as pcp
from io import StringIO
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.core.structure import Molecule
import matplotlib.pyplot as plt
# MatterSim is an optional dependency. Import it defensively so that a missing or
# binary-incompatible installation disables the MatterSim models instead of
# crashing the whole app at start-up; downstream code can check
# `mattersim_available` before offering those models.
try:
    from mattersim.forcefield import MatterSimCalculator
    mattersim_available = True
except Exception as e:  # ImportError, or e.g. numpy ABI mismatches surfacing as other errors
    print(f"Failed to import MatterSimCalculator: {e}")
    MatterSimCalculator = None  # keep the name bound so later references fail predictably
    mattersim_available = False
from huggingface_hub import login
# Optional Hugging Face login so gated model weights can be downloaded.
# "YOUR SECRET KEY" is the placeholder secret name this app originally read; the
# conventional HF_TOKEN variable is used as a fallback so a standard deployment
# secret is picked up without editing the code.
try:
    hf_token = os.getenv("YOUR SECRET KEY") or os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        print("Hugging Face token not found. Some models might not be accessible.")
except Exception as e:
    # Best effort: a failed login must not prevent the app from starting.
    print(f"hf login error: {e}")
# NOTE(review): Streamlit reads its configuration (option server.fileWatcherType,
# env var STREAMLIT_SERVER_FILE_WATCHER_TYPE) when the server starts, so setting
# this variable from inside the already-running script most likely has no effect.
# Prefer .streamlit/config.toml or the launch environment. Kept for compatibility.
os.environ["STREAMLIT_WATCHER_TYPE"] = "none"
# YAML data for FairChem per-element reference energies.
# Top-level keys name the reference set (oc20, odac, omat, omol, omc); each value
# is a list of 100 numbers that appears to be indexed by atomic number (entry 0 is
# a 0.0 placeholder). Entries of exactly 0.0 look like "no reference available".
# NOTE(review): units are presumably eV (omol H ~ -13.4, He ~ -78.8) — confirm
# against the FairChem source these tables were copied from.
ELEMENT_REF_ENERGIES_YAML = """
oc20_elem_refs:
- 0.0
- -0.16141512
- 0.03262098
- -0.04787699
- -0.06299825
- -0.14979306
- -0.11657468
- -0.10862579
- -0.10298174
- -0.03420248
- 0.02673997
- -0.03729558
- 0.00515243
- -0.07535697
- -0.13663351
- -0.12922852
- -0.11796547
- -0.07802946
- -0.00672682
- -0.04089589
- -0.00024177
- -1.74545186
- -1.54220241
- -1.0934019
- -1.16168372
- -1.23073475
- -0.78852824
- -0.71851599
- -0.52465053
- -0.02692092
- -0.00317922
- -0.06266862
- -0.10835274
- -0.12394474
- -0.11351727
- -0.07455817
- -0.00258354
- -0.04111325
- -0.02090265
- -1.89306078
- -1.30591887
- -0.63320009
- -0.26230344
- -0.2633669
- -0.5160055
- -0.95950798
- -1.45589361
- -0.0429969
- -0.00026949
- -0.05925609
- -0.09734631
- -0.12406852
- -0.11427538
- -0.07021442
- 0.01091345
- -0.05305289
- -0.02427209
- -0.19975668
- -1.71692859
- -1.53677781
- -3.89987009
- -10.70940462
- -6.71693816
- -0.28102249
- -8.86944824
- -7.95762687
- -7.13041437
- -6.64620014
- -5.11482482
- -4.42548227
- 0.00848295
- -0.06956227
- -2.6748853
- -2.21153293
- -1.67367741
- -1.07636151
- -0.79009981
- -0.16387243
- -0.18164401
- -0.04122529
- -0.00041833
- -0.05259382
- -0.0934314
- -0.11023834
- -0.10039175
- -0.06069209
- 0.01790437
- -0.04694024
- 0.00334084
- -0.06030621
- -0.58793619
- -1.27821808
- -4.97483577
- -5.66985655
- -8.43154622
- -11.15001317
- -12.95770812
- 0.0
- -14.47602729
- 0.0
odac_elem_refs:
- 0.0
- -1.11737936
- -0.00011835
- -0.2941727
- -0.03868426
- -0.34862832
- -1.31552566
- -3.12457285
- -1.6052078
- -0.49653389
- -0.01137327
- -0.21957281
- -0.0008343
- -0.2750172
- -0.88417265
- -1.887378
- -0.94903558
- -0.31628167
- -0.02014536
- -0.15901053
- -0.00731884
- -1.96521355
- -1.89045209
- -2.53057428
- -5.43600675
- -5.09739336
- -3.03088746
- -1.23786562
- -0.40650749
- -0.2416017
- -0.01139188
- -0.26282496
- -0.82446455
- -1.70237206
- -0.84245376
- -0.28544892
- -0.02239991
- -0.14115912
- -0.02840799
- -2.09540994
- -1.85863996
- -1.12257399
- -4.32965355
- -3.30670045
- -1.19460755
- -1.26257601
- -1.46832888
- -0.19779414
- -0.0144274
- -0.23668767
- -0.70836953
- -1.43186113
- -0.71701186
- -0.24883129
- -0.01118184
- -0.13173447
- -0.0318395
- -0.41195547
- -1.23134873
- -2.03082996
- 0.1375954
- -5.45866275
- -7.59139905
- -5.99965965
- -8.43495767
- -2.6578407
- -7.77349787
- -5.30762201
- -5.15109657
- -4.41466995
- -0.02995219
- -0.2544495
- -3.23821202
- -3.45887214
- -4.53635003
- -4.60979468
- -2.90707964
- -1.28286153
- -0.57716664
- -0.18337108
- -0.01135944
- -0.22045398
- -0.66150479
- -1.32506342
- -0.66500178
- -0.22643927
- -0.00728197
- -0.11208472
- -0.00757856
- -0.21798637
- -0.91078787
- -1.78187161
- -3.89912261
- -3.94192659
- -7.59026042
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omat_elem_refs:
- 0.0
- -1.11700253
- 0.00079886
- -0.29731164
- -0.04129868
- -0.29106192
- -1.27751531
- -3.12342715
- -1.54797136
- -0.43969356
- -0.01250908
- -0.22855413
- -0.00943179
- -0.21707638
- -0.82619133
- -1.88667434
- -0.89093583
- -0.25816211
- -0.02414768
- -0.17662425
- -0.02568319
- -2.13001165
- -2.38688845
- -3.55934233
- -5.44700879
- -5.14749562
- -3.30662847
- -1.42167737
- -0.63181379
- -0.23449167
- -0.01146636
- -0.21291259
- -0.77939897
- -1.70148487
- -0.78386705
- -0.22690657
- -0.02245409
- -0.16092396
- -0.02798717
- -2.25685695
- -2.23690495
- -2.15347771
- -4.60251809
- -3.36416792
- -2.23062607
- -1.15550917
- -1.47553527
- -0.19918102
- -0.01475888
- -0.19767692
- -0.68005773
- -1.43073368
- -0.65790462
- -0.18915279
- -0.01179476
- -0.13507902
- -0.03056979
- -0.36017439
- -0.86279246
- -0.20573327
- -0.2734463
- -0.20046965
- -0.25444338
- -8.37972664
- -9.58424928
- -0.19466184
- -0.24860115
- -0.19531288
- -0.15401392
- -0.14577898
- -0.19655747
- -0.15645898
- -3.49380556
- -3.5317097
- -4.57108006
- -4.63425205
- -2.88247063
- -1.45679675
- -0.50290184
- -0.18521704
- -0.01123956
- -0.17483649
- -0.63132037
- -1.3248562
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- -0.24135757
- -1.04601971
- -2.04574044
- -3.84544799
- -7.28626119
- -7.3136314
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omol_elem_refs:
- 0.0
- -13.44558
- -78.82027
- -203.32564
- -398.94742
- -670.75275
- -1029.85403
- -1485.54188
- -2042.97832
- -2714.24015
- -3508.74317
- -4415.24203
- -5443.89712
- -6594.61834
- -7873.6878
- -9285.6593
- -10832.62132
- -12520.66852
- -14354.278
- -16323.54671
- -18436.47845
- -20696.18244
- -23110.5386
- -25682.99429
- -28418.37804
- -31317.92317
- -34383.42519
- -37623.46835
- -41039.92413
- -44637.38634
- -48417.14864
- -52373.87849
- -56512.76952
- -60836.14871
- -65344.28833
- -70041.24251
- -74929.56277
- -653.64777
- -833.31922
- -1038.0281
- -1273.96788
- -1542.45481
- -1850.74158
- -2193.91654
- -2577.18734
- -3004.13604
- -3477.52796
- -3997.31825
- -4563.75804
- -5171.82293
- -5828.85334
- -6535.61529
- -7291.54792
- -8099.87914
- -8962.17916
- -546.03214
- -690.6089
- -854.11237
- -12923.04096
- -14064.26124
- -15272.68689
- -16550.20551
- -17900.36515
- -19323.23406
- -20829.08848
- -22428.73258
- -24078.68008
- -25794.42097
- -27616.6819
- -29523.5526
- -31526.68012
- -33615.37779
- -1300.17791
- -1544.40924
- -1818.62298
- -2123.14417
- -2461.76028
- -2833.76287
- -3242.79895
- -3690.363
- -4174.99772
- -4691.75674
- -5245.36013
- -5838.12005
- -6469.07296
- -7140.86455
- -7854.60638
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
- 0.0
omc_elem_refs:
- 0.0
- -0.02831808
- 4.512e-05
- -0.03227157
- -0.03842519
- -0.05829283
- -0.0845041
- -0.08806738
- -0.09021346
- -0.06669846
- -0.01218631
- -0.03650269
- -0.00059093
- -0.05787736
- -0.08730952
- -0.0975534
- -0.09264199
- -0.07124762
- -0.02374602
- -0.05299112
- -0.02631476
- -1.7772147
- -1.25083444
- -0.79579447
- -0.49099317
- -0.31414986
- -0.20292182
- -0.14011632
- -0.09929659
- -0.03771207
- -0.01117902
- -0.06168715
- -0.08873364
- -0.09512942
- -0.09035978
- -0.06910849
- -0.02244872
- -0.05303651
- -0.02871903
- -1.94805417
- -1.33379896
- -0.69169331
- -0.26184306
- -0.20631599
- -0.48251608
- -0.96911893
- -1.47569462
- -0.03845194
- -0.0142445
- -0.07118991
- -0.09940292
- -0.09235056
- -0.08755943
- -0.06544925
- -0.01246646
- -0.04692937
- -0.03225123
- -0.26086039
- -27.20024339
- -0.08412926
- -0.08225924
- -0.07799715
- -0.07806185
- 0.00043759
- -0.07459766
- 0.0
- -0.06842841
- -0.07758266
- -0.07025152
- -0.08055003
- -0.07118177
- -0.07159568
- -2.69202862
- -2.21926765
- -1.679756
- -1.06135075
- -0.4554231
- -0.14488432
- -0.18377098
- -0.03603118
- -0.01076585
- -0.06381411
- -0.0905623
- -0.10095787
- -0.09501217
- -0.0574478
- -0.00599173
- -0.04134751
- -0.0082683
- -0.08704692
- -0.49656425
- -5.24233138
- -2.32542606
- -4.3376616
- -5.96430676
- 0.0
- 0.0
- -0.03842519
- 0.0
- 0.0
"""
def _parse_reference_energies(yaml_text):
    """Parse the embedded reference-energy YAML, falling back to an empty dict.

    Streamlit elements (e.g. st.error) are deliberately not used here because this
    runs at import time; a parse failure is reported on stdout instead.
    """
    try:
        return yaml.safe_load(yaml_text)
    except yaml.YAMLError as parse_error:
        print(f"Error parsing YAML reference energies: {parse_error}")
        return {}  # empty fallback keeps later lookups from crashing at import time


ELEMENT_REF_ENERGIES = _parse_reference_energies(ELEMENT_REF_ENERGIES_YAML)
# Check if running on Streamlit Cloud vs locally.
# NOTE(review): assumes the hosting environment sets STREAMLIT_RUNTIME_ENV=cloud —
# confirm; when the variable is absent the app treats itself as running locally.
is_streamlit_cloud = os.environ.get('STREAMLIT_RUNTIME_ENV') == 'cloud'
MAX_ATOMS_CLOUD = 500  # Maximum atoms allowed on Streamlit Cloud (general models)
MAX_ATOMS_CLOUD_UMA = 500  # Limit applied when the model name contains 'UMA' or 'ESEN MD'
# Set page configuration. Keep this before any other Streamlit call: older
# Streamlit versions require set_page_config to be the first st command.
st.set_page_config(
    page_title="MLIP Playground - Run, Test and Benchmark MLIPs",
    page_icon="🧪",
    layout="wide"
)
# Title and description
st.markdown('## MLIP Playground', unsafe_allow_html=True)
st.write('#### Run, test and compare 22 state-of-the-art universal machine learning interatomic potentials (MLIPs) for atomistic simulations of molecules and materials')
st.markdown('Upload molecular structure files or select from predefined examples, then compute energies and forces using foundation models such as those from MACE or FairChem (Meta).', unsafe_allow_html=True)
# Create a directory for sample structures if it doesn't exist.
# The path is relative to the current working directory, not to this script.
SAMPLE_DIR = "sample_structures"
os.makedirs(SAMPLE_DIR, exist_ok=True)
# Dictionary of sample structures: display name -> file name inside SAMPLE_DIR
SAMPLE_STRUCTURES = {
    "Water": "H2O.xyz",
    "Methane": "CH4.xyz",
    "Benzene": "C6H6.xyz",
    "Ethane": "C2H6.xyz",
    "Caffeine": "caffeine.xyz",
    "Ibuprofen": "ibuprofen.xyz",
    "Silicon": "Si.cif",
    "hBN Monolayer (4x4)": "hBN_monolayer_4x4_supercell.extxyz",
}
def get_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
                       show_path=True, path_color='red', path_radius=0.02):
    """
    Render every frame of a trajectory in a single py3Dmol viewer (frames overlaid).

    Args:
        trajectory: Sequence of ASE Atoms objects, one per optimization step.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); any other value
            falls back to thin sticks.
        show_unit_cell: Draw the cell of the final frame when it is periodic.
        width, height: Viewer size in pixels.
        show_path: Connect each atom's positions in consecutive frames with cylinders.
        path_color: Colour of the path cylinders.
        path_radius: Radius of the path cylinders.

    Returns:
        The py3Dmol view, or None when the trajectory is empty.
    """
    if not trajectory:
        return None

    def as_point(vec):
        # py3Dmol expects positions as {'x', 'y', 'z'} dictionaries.
        return {'x': vec[0], 'y': vec[1], 'z': vec[2]}

    viewer = py3Dmol.view(width=width, height=height)

    # One XYZ model per frame.
    for step, snapshot in enumerate(trajectory):
        xyz_lines = [str(len(snapshot)), f"Frame {step}"]
        xyz_lines.extend(
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
            for atom in snapshot
        )
        viewer.addModel("\n".join(xyz_lines) + "\n", "xyz")

    # Style applies to all loaded models.
    style_specs = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    # Per-atom displacement paths between consecutive frames
    # (assumes every frame has the same atoms in the same order).
    if show_path and len(trajectory) > 1:
        last_step = len(trajectory) - 1
        for atom_index in range(len(trajectory[0])):
            for step in range(last_step):
                viewer.addCylinder({
                    'start': as_point(trajectory[step][atom_index].position),
                    'end': as_point(trajectory[step + 1][atom_index].position),
                    'radius': path_radius,
                    'color': path_color,
                    'alpha': 0.5
                })

    # Unit cell of the final frame, drawn as its 12 edges.
    if show_unit_cell and trajectory[-1].pbc.any():
        cell = trajectory[-1].get_cell()
        origin = np.array([0.0, 0.0, 0.0])
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c),
                (a + b, a + b + c)
            ]
            for start, end in edges:
                viewer.addCylinder({
                    'start': as_point(start),
                    'end': as_point(end),
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
def get_animated_trajectory_viz(trajectory, style='stick', show_unit_cell=True, width=400, height=400,
                                interval=500, loop='forward'):
    """
    Create an animated py3Dmol visualization of a trajectory.

    All frames are loaded into ONE model with multiple frames
    (``addModelsAsFrames``). 3Dmol.js animates the frames of a model. The previous
    approach added one separate model per frame, which only overlaid every frame
    statically and left nothing to animate.

    Args:
        trajectory: Sequence of ASE Atoms objects, one per step (assumed to keep
            the same atom count and order).
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); other values
            fall back to thin sticks.
        show_unit_cell: Draw the cell of the final frame when it is periodic.
        width, height: Viewer size in pixels.
        interval: Delay between frames in milliseconds (default 500, as before).
        loop: 3Dmol loop mode, e.g. 'forward' (default) or 'backAndForth'.

    Returns:
        The py3Dmol view, or None when the trajectory is empty.
    """
    if not trajectory:
        return None
    view = py3Dmol.view(width=width, height=height)

    # Concatenate all frames into one multi-frame XYZ document.
    frame_blocks = []
    for frame_idx, atoms_obj in enumerate(trajectory):
        lines = [f"{len(atoms_obj)}", f"Frame {frame_idx}"]
        lines += [
            f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
            for atom in atoms_obj
        ]
        frame_blocks.append("\n".join(lines) + "\n")
    view.addModelsAsFrames("".join(frame_blocks), "xyz")

    # Set style
    style_key = style.lower()
    if style_key == 'ball-stick':
        view.setStyle({'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}})
    elif style_key == 'stick':
        view.setStyle({'stick': {}})
    elif style_key == 'ball':
        view.setStyle({'sphere': {'scale': 0.4}})
    else:
        view.setStyle({'stick': {'radius': 0.15}})

    # Add unit cell of the last frame (12 parallelepiped edges)
    if show_unit_cell and trajectory[-1].pbc.any():
        cell = trajectory[-1].get_cell()
        origin = np.array([0.0, 0.0, 0.0])
        if cell is not None and cell.any():
            edges = [
                (origin, cell[0]), (origin, cell[1]), (cell[0], cell[0] + cell[1]), (cell[1], cell[0] + cell[1]),
                (origin, cell[2]), (cell[0], cell[0] + cell[2]), (cell[1], cell[1] + cell[2]),
                (cell[0] + cell[1], cell[0] + cell[1] + cell[2]),
                (cell[2], cell[2] + cell[0]), (cell[2], cell[2] + cell[1]),
                (cell[2] + cell[0], cell[2] + cell[0] + cell[1]), (cell[2] + cell[1], cell[2] + cell[0] + cell[1])
            ]
            for start, end in edges:
                view.addCylinder({
                    'start': {'x': start[0], 'y': start[1], 'z': start[2]},
                    'end': {'x': end[0], 'y': end[1], 'z': end[2]},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })
    view.zoomTo()
    view.setBackgroundColor('white')
    # Enable animation; reps=0 loops indefinitely.
    view.animate({'loop': loop, 'reps': 0, 'interval': interval})
    return view
# Streamlit implementation example
def display_optimization_trajectory(trajectory, viz_style='ball-stick'):
    """
    Display an optimization trajectory in Streamlit with visualization controls.

    The user picks one of three modes:
      * "Animation"         - frames played in a loop (get_animated_trajectory_viz)
      * "Static with paths" - all frames overlaid with optional per-atom paths
                              (get_trajectory_viz)
      * "Step-by-step"      - one frame chosen with a slider (get_structure_viz2)

    Args:
        trajectory: Sequence of ASE Atoms objects, one per optimization step.
        viz_style: Style name forwarded to the py3Dmol helper functions.
    """
    if not trajectory:
        st.error("No trajectory data available")
        return
    n_frames = len(trajectory)
    st.subheader(f"Optimization Trajectory ({n_frames} steps)")
    # Trajectory options
    col1, col2 = st.columns(2)
    with col1:
        viz_mode = st.selectbox(
            "Visualization Mode",
            ["Animation", "Static with paths", "Step-by-step"],
            key="viz_mode"
        )
    # Defaults so every rendering branch below always has its inputs bound.
    show_paths, path_color, frame_idx = True, "red", 0
    with col2:
        if viz_mode == "Static with paths":
            show_paths = st.checkbox("Show trajectory paths", value=True)
            path_color = st.selectbox("Path color", ["red", "blue", "green", "orange"], index=0)
        # A single-frame trajectory would need a 0..0 slider, which Streamlit
        # rejects (min_value must be below max_value); frame 0 is used instead.
        elif viz_mode == "Step-by-step" and n_frames > 1:
            frame_idx = st.slider("Frame", 0, n_frames - 1, 0, key="frame_slider")
    # Build the view for the selected mode
    if viz_mode == "Static with paths":
        opt_view = get_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400,
            show_path=show_paths,
            path_color=path_color
        )
    elif viz_mode == "Animation":
        opt_view = get_animated_trajectory_viz(
            trajectory,
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400
        )
    elif viz_mode == "Step-by-step":
        opt_view = get_structure_viz2(
            trajectory[frame_idx],
            style=viz_style,
            show_unit_cell=True,
            width=400,
            height=400
        )
    else:
        return
    # _make_html() is py3Dmol's HTML export; the viewer is embedded as raw HTML.
    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    if viz_mode == "Step-by-step":
        st.write(f"Step {frame_idx + 1} of {n_frames}")
def get_structure_viz2(atoms_obj, style='stick', show_unit_cell=True, width=400, height=400):
    """
    Build a py3Dmol view of a single structure.

    Args:
        atoms_obj: ASE Atoms object to display.
        style: 'stick', 'ball' or 'ball-stick' (case-insensitive); any other value
            falls back to thin sticks.
        show_unit_cell: Draw the cell edges when the structure is periodic and the
            cell is non-zero.
        width, height: Viewer size in pixels.

    Returns:
        The configured py3Dmol view.
    """
    atom_lines = "".join(
        f"{site.symbol} {site.position[0]:.6f} {site.position[1]:.6f} {site.position[2]:.6f}\n"
        for site in atoms_obj
    )
    xyz_text = f"{len(atoms_obj)}\nStructure\n{atom_lines}"

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(xyz_text, "xyz")

    style_specs = {
        'ball-stick': {'stick': {'radius': 0.1}, 'sphere': {'scale': 0.3}},
        'stick': {'stick': {}},
        'ball': {'sphere': {'scale': 0.4}},
    }
    viewer.setStyle(style_specs.get(style.lower(), {'stick': {'radius': 0.15}}))

    # Cell edges only for periodic structures with a non-zero cell.
    if show_unit_cell and atoms_obj.pbc.any():
        cell = atoms_obj.get_cell()
        origin = np.array([0.0, 0.0, 0.0])
        if cell is not None and cell.any():
            a, b, c = cell[0], cell[1], cell[2]
            edges = [
                (origin, a), (origin, b), (a, a + b), (b, a + b),
                (c, c + a), (c, c + b),
                (c + a, c + a + b), (c + b, c + a + b),
                (origin, c), (a, a + c), (b, b + c),
                (a + b, a + b + c)
            ]
            for start, end in edges:
                viewer.addCylinder({
                    'start': {'x': start[0], 'y': start[1], 'z': start[2]},
                    'end': {'x': end[0], 'y': end[1], 'z': end[2]},
                    'radius': 0.05, 'color': 'black', 'alpha': 0.7
                })

    viewer.zoomTo()
    viewer.setBackgroundColor('white')
    return viewer
opt_log = []  # Rows appended by streamlit_log, one per optimizer callback
table_placeholder = st.empty()  # Re-rendered with the full log table on every callback


def streamlit_log(opt):
    """
    Optimizer callback that records the current step and refreshes the log table.

    Reads ``opt.atoms`` (energy and forces) and ``opt.nsteps``, appends a row to the
    module-level ``opt_log`` list and redraws it in ``table_placeholder``. Any error
    is shown as a Streamlit warning instead of interrupting the optimization.
    """
    # opt_log / table_placeholder are only mutated, never rebound, so no
    # `global` declaration is needed.
    try:
        current = opt.atoms
        energy = current.get_potential_energy()
        forces = current.get_forces()
        if forces.shape[0] > 0:
            fmax_step = np.max(np.linalg.norm(forces, axis=1))
        else:
            fmax_step = 0.0
        row = {
            "Step": opt.nsteps,
            "Energy (eV)": round(energy, 6),
            "Fmax (eV/Å)": round(fmax_step, 6),
        }
        opt_log.append(row)
        table_placeholder.dataframe(pd.DataFrame(opt_log))
    except Exception as err:
        st.warning(f"Error in optimization logger: {err}")
def check_atom_limit(atoms_obj, selected_model):
    """
    Check whether a structure fits within the atom limit for the selected model.

    Models whose name contains 'UMA' or 'ESEN MD' use MAX_ATOMS_CLOUD_UMA; all
    others use MAX_ATOMS_CLOUD. When the limit is exceeded, an error is shown via
    st.error.

    Returns:
        True when there is no structure or it is within the limit, else False.
    """
    if atoms_obj is None:
        return True
    atom_count = len(atoms_obj)
    uses_uma_limit = any(tag in selected_model for tag in ('UMA', 'ESEN MD'))
    limit = MAX_ATOMS_CLOUD_UMA if uses_uma_limit else MAX_ATOMS_CLOUD
    if atom_count <= limit:
        return True
    st.error(f"⚠️ Error: Your structure contains {atom_count} atoms, exceeding the {limit} atom limit for this model on Streamlit Cloud. Please run locally for larger systems.")
    return False
# Display name -> URL of a MACE foundation-model checkpoint.
# Presumably passed as `model_path` to get_mace_model (mace_mp(model=...)) — verify
# against the model-selection code.
MACE_MODELS = {
    "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
    "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
    "MACE OMAT Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-small.model",
    "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
    "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
    "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
    "MACE MP 0a Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-03-mace-128-L1_epoch-199.model",
    "MACE MP 0a Large": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2024-01-07-mace-128-L2_epoch-199.model",
    "MACE MP 0b Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_small.model",
    "MACE MP 0b Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b/mace_agnesi_medium.model",
    "MACE MP 0b2 Small": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-small-density-agnesi-stress.model", # Corrected name from original code
    "MACE MP 0b2 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-medium-density-agnesi-stress.model",
    "MACE MP 0b2 Large": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b2/mace-large-density-agnesi-stress.model",
    "MACE MP 0b3 Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mp_0b3/mace-mp-0b3-medium.model",
    "MACE ANI-CC Large (500k)": "https://github.com/ACEsuit/mace/raw/main/mace/calculators/foundations_models/ani500k_large_CC.model",
    "MACE OMOL-0 XL 4M": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/mace-omol-0-extra-large-4M.model",
    "MACE OMOL-0 XL 1024": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/MACE-omol-0-extra-large-1024.model"
}
# Display name -> FairChem pretrained model name (passed to
# pretrained_mlip.get_predict_unit in get_fairchem_model, presumably). Display names
# matter too: get_fairchem_model uses the chosen task only when the name contains
# "UMA Small", and forces task "omol" otherwise.
FAIRCHEM_MODELS = {
    "UMA Small 1": "uma-s-1",
    "UMA Small 1.1": "uma-s-1p1",
    "ESEN MD Direct All OMOL": "esen-md-direct-all-omol",
    "ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
    "ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
}
# Define the available ORB models: display name -> orb_models `pretrained` loader
# function (stored uncalled; presumably invoked later to build the model).
ORB_MODELS = {
    "V3 OMOL Conservative": pretrained.orb_v3_conservative_omol,
    "V3 OMOL Direct": pretrained.orb_v3_direct_omol,
    "V3 OMAT Conservative (inf)": pretrained.orb_v3_conservative_inf_omat,
    "V3 OMAT Conservative (20)": pretrained.orb_v3_conservative_20_omat,
    "V3 OMAT Direct (inf)": pretrained.orb_v3_direct_inf_omat,
    "V3 OMAT Direct (20)": pretrained.orb_v3_direct_20_omat,
    "V3 MPA Conservative (inf)": pretrained.orb_v3_conservative_inf_mpa,
    "V3 MPA Conservative (20)": pretrained.orb_v3_conservative_20_mpa,
    "V3 MPA Direct (inf)": pretrained.orb_v3_direct_inf_mpa,
    "V3 MPA Direct (20)": pretrained.orb_v3_direct_20_mpa,
}
# Define the available MatterSim models: display name -> checkpoint file name
MATTERSIM_MODELS = {
    "V1 SMALL": "MatterSim-v1.0.0-1M.pth",
    "V1 LARGE": "MatterSim-v1.0.0-5M.pth"
}
# SevenNet models: display name -> SevenNet model identifier (identical strings)
SEVEN_NET_MODELS = {
    "7net-0": "7net-0",
    "7net-l3i5": "7net-l3i5",
    "7net-omat": "7net-omat",
    "7net-mf-ompa": "7net-mf-ompa"
}
@st.cache_resource
def get_mace_model(model_path, dispersion, device, selected_default_dtype):
    """
    Build a MACE calculator with ``mace_mp``, cached per argument combination.

    Args:
        model_path: Checkpoint path/URL or name accepted by ``mace_mp(model=...)``.
        dispersion: Forwarded as ``mace_mp(dispersion=...)``.
        device: Torch device string forwarded to ``mace_mp``.
        selected_default_dtype: Forwarded as ``mace_mp(default_dtype=...)``.
    """
    calculator = mace_mp(
        model=model_path,
        dispersion=dispersion,
        device=device,
        default_dtype=selected_default_dtype,
    )
    return calculator
@st.cache_resource
def get_fairchem_model(selected_model_name, model_path_or_name, device, selected_task_type_fc):
    """
    Build a cached FAIRChemCalculator for a pretrained FairChem model.

    Only "UMA Small" models honour ``selected_task_type_fc``. Every other model
    is built with task ``"omol"``.

    Args:
        selected_model_name: Display name; checked for the substring "UMA Small".
        model_path_or_name: Name passed to ``pretrained_mlip.get_predict_unit``.
        device: Device forwarded to ``get_predict_unit``.
        selected_task_type_fc: Task name used for UMA Small models.
    """
    predictor = pretrained_mlip.get_predict_unit(model_path_or_name, device=device)
    task_name = selected_task_type_fc if "UMA Small" in selected_model_name else "omol"
    return FAIRChemCalculator(predictor, task_name=task_name)
# --- INITIALIZATION (Must be run first) ---
# Seed session-state keys on the first run of a session; later reruns keep their
# values. "current_input_method" tracks the active input method so that switching
# methods can discard a structure loaded through the previous one.
for _state_key, _state_default in (
    ("atoms", None),
    ("atoms_list", []),
    ("current_input_method", "Select Example"),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default

st.sidebar.markdown("## Input Options")
input_method = st.sidebar.radio(
    "Choose Input Method:",
    ["Select Example", "Upload File", "Paste Content", "Materials Project ID", "PubChem", "Batch Upload"],
)
# Switching input method invalidates the loaded structure so stale data is never used.
if input_method != st.session_state.current_input_method:
    st.session_state.atoms = None
    st.session_state.current_input_method = input_method
# --- UPLOAD FILE ---
if input_method == "Upload File":
uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
# Load immediately upon file upload/change (no button needed)
if uploaded_file:
try:
# Check if this file content has already been loaded to prevent redundant temp file operations
if 'uploaded_file_hash' not in st.session_state or st.session_state.uploaded_file_hash != uploaded_file.name:
# Use tempfile to handle the uploaded file content
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
tmp_file.write(uploaded_file.getvalue())
tmp_filepath = tmp_file.name
atoms_to_store = read(tmp_filepath)
st.session_state.atoms = atoms_to_store
st.session_state.uploaded_file_hash = uploaded_file.name # Track the loaded file
st.sidebar.success(f"Successfully loaded structure with {len(atoms_to_store)} atoms!")
except Exception as e:
st.sidebar.error(f"Error loading file: {str(e)}")
st.session_state.atoms = None
st.session_state.uploaded_file_hash = None # Clear hash on failure
finally:
# Clean up the temporary file
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
os.unlink(tmp_filepath)
else:
# Clear structure if file uploader is empty
st.session_state.atoms = None
# --- SELECT EXAMPLE ---
elif input_method == "Select Example":
# Load immediately upon selection change (no button needed)
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
# Only load if a valid example is selected and it's different from the current state
if example_name and (st.session_state.atoms is None or st.session_state.atoms.info.get('source_name') != example_name):
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
try:
atoms_to_store = read(file_path)
atoms_to_store.info['source_name'] = example_name # Add a tag for tracking
st.session_state.atoms = atoms_to_store
st.sidebar.success(f"Loaded {example_name} with {len(atoms_to_store)} atoms!")
except Exception as e:
st.sidebar.error(f"Error loading example: {str(e)}")
st.session_state.atoms = None
# --- PASTE CONTENT ---
elif input_method == "Paste Content":
file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
content = st.sidebar.text_area("Paste file content here:", height=200, key="paste_content_input")
# Load immediately upon content change (no button needed)
# Check if content is present and is different from the last successfully parsed content
if content:
# Simple check to avoid parsing on every single character change
if 'last_parsed_content' not in st.session_state or st.session_state.last_parsed_content != content:
try:
suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
suffix = suffix_map.get(file_format, ".xyz")
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
tmp_file.write(content.encode())
tmp_filepath = tmp_file.name
atoms_to_store = read(tmp_filepath)
st.session_state.atoms = atoms_to_store
st.session_state.last_parsed_content = content # Track the parsed content
st.sidebar.success(f"Successfully parsed structure with {len(atoms_to_store)} atoms!")
except Exception as e:
st.sidebar.error(f"Error parsing content: {str(e)}")
st.session_state.atoms = None
st.session_state.last_parsed_content = None
finally:
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
os.unlink(tmp_filepath)
else:
# Clear structure if text area is empty
st.session_state.atoms = None
# --- PUBCHEM SEARCH MODE ---
elif input_method == "PubChem":
st.sidebar.markdown("### Search PubChem")
# --- PUBCHEM LOOKUP ---
# Search PubChem with the text query, list matching compounds in the sidebar, and load
# the 3D conformer of the selected CID into st.session_state.atoms. Both the last query
# and the last loaded CID are cached in session state so reruns do not re-query PubChem.
query = st.sidebar.text_input("Enter name or formula (e.g., H2O, water, methane):",
key="pubchem_query", value="water")
# Reset atoms if no query
if query.strip() == "":
st.session_state.atoms = None
# Step 1: Search PubChem
if query and query.strip():
# Avoid re-searching if query is unchanged
if "pubchem_last_query" not in st.session_state or st.session_state.pubchem_last_query != query:
try:
with st.spinner("Searching PubChem..."):
# NOTE(review): the lookup uses PubChem's "name" namespace; whether formula strings resolve depends on PubChem — confirm
results = pcp.get_compounds(query, "name") # name OR formula works
st.session_state.pubchem_results = results
st.session_state.pubchem_last_query = query
except Exception as e:
# pubchem_last_query is only recorded on success, so a failed search is retried on the next rerun.
st.sidebar.error(f"Error searching PubChem: {str(e)}")
st.session_state.pubchem_results = None
results = st.session_state.get("pubchem_results", [])
if results:
# Convert to displayable table
df = pd.DataFrame(
[(c.cid, c.iupac_name, c.molecular_formula, c.molecular_weight, c.isomeric_smiles)
for c in results],
columns=["CID", "Name", "Formula", "Weight", "SMILES"]
)
st.sidebar.success(f"Found {len(df)} result(s).")
st.sidebar.dataframe(df)
# Choose a CID
cid = st.sidebar.selectbox("Select CID", df["CID"], key="pubchem_cid")
# Step 2: Retrieve 3D structure for selected CID (only when the selection changed)
if cid:
if "pubchem_last_cid" not in st.session_state or st.session_state.pubchem_last_cid != cid:
try:
with st.spinner("Fetching 3D coordinates..."):
# Function to format floating-point numbers with alignment
def format_number(num, width=10, precision=5):
# Handles positive/negative numbers while maintaining alignment
return f"{num: {width}.{precision}f}"
# CID to XYZ: build an XYZ-format string (atom count, a comment line holding the CID,
# then one "symbol x y z" line per atom) from PubChem's 3D record.
def generate_xyz_coordinates(cid):
compound = pcp.Compound.from_cid(cid, record_type='3d')
atoms = compound.atoms
coords = [(atom.x, atom.y, atom.z) for atom in atoms]
num_atoms = len(atoms)
xyz_text = f"{num_atoms}\n{compound.cid}\n"
for atom, coord in zip(atoms, coords):
atom_symbol = atom.element
x, y, z = coord
xyz_text += f"{atom_symbol} {format_number(x, precision=8)} {format_number(y, precision=8)} {format_number(z, precision=8)}\n"
return xyz_text
# Returns (pymatgen Molecule, raw XYZ string) for the CID.
def get_molecule(cid):
xyz_str = generate_xyz_coordinates(cid)
return Molecule.from_str(xyz_str, fmt='xyz'), xyz_str
# (Unused alternative: fetch an SDF with the 3D conformer instead of building XYZ.)
# sdf_str = pcp.Compound.from_cid(int(cid)).to_sdf()
# NOTE(review): selected_molecule is not used within this section.
selected_molecule, xyz_str = get_molecule(cid)
# Parse the XYZ string into ASE Atoms via an in-memory text buffer
atoms_to_store = read(StringIO(xyz_str), format="xyz")
atoms_to_store.info["source_name"] = f"PubChem CID {cid}"
st.session_state.atoms = atoms_to_store
st.session_state.pubchem_last_cid = cid
st.sidebar.success(f"Loaded PubChem structure with {len(atoms_to_store)} atoms!")
except Exception as e:
# Any failure (e.g. no 3D record for this CID) clears the structure and the CID cache.
st.sidebar.error(f"Unable to retrieve 3D structure: {str(e)}")
st.session_state.atoms = None
st.session_state.pubchem_last_cid = None
else:
st.sidebar.info("No PubChem results found.")
# --- MATERIALS PROJECT ID ---
# Load a structure by Materials Project ID (requires the MP_API_KEY environment variable),
# standardize it to a primitive or conventional cell, and store it in st.session_state.atoms.
elif input_method == "Materials Project ID":
mp_api_key = os.getenv("MP_API_KEY")
material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149", key="mp_id_input")
cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'], key="cell_type_radio")
# Reactive Loading (No button needed)
# Check for valid inputs and if the current material_id/cell_type is different from the loaded one
if mp_api_key and material_id:
# Simple tracking to avoid API call if nothing has changed
current_mp_key = f"{material_id}_{cell_type}"
if 'last_fetched_mp_key' not in st.session_state or st.session_state.last_fetched_mp_key != current_mp_key:
try:
with st.spinner(f"Fetching {material_id}..."):
with MPRester(mp_api_key) as mpr:
pmg_structure = mpr.get_structure_by_material_id(material_id)
# Standardize the cell with pymatgen's symmetry analyzer before converting to ASE.
analyzer = SpacegroupAnalyzer(pmg_structure)
if cell_type == 'Conventional Unit Cell':
final_structure = analyzer.get_conventional_standard_structure()
else:
final_structure = analyzer.get_primitive_standard_structure()
atoms_to_store = AseAtomsAdaptor.get_atoms(final_structure)
# NOTE(review): unlike the PubChem and batch loaders, no info["source_name"] is set here.
st.session_state.atoms = atoms_to_store
st.session_state.last_fetched_mp_key = current_mp_key # Update tracking key
st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
except Exception as e:
st.sidebar.error(f"Error fetching data: {str(e)}")
st.session_state.atoms = None
st.session_state.last_fetched_mp_key = None # Clear key on failure
# Handle error messages when inputs are missing
elif not mp_api_key:
st.sidebar.error("Please set your Materials Project API Key (MP_API_KEY environment variable).")
elif not material_id:
st.sidebar.error("Please enter a Material ID.")
# --- BATCH UPLOAD MULTIPLE FILES ---
elif input_method == "Batch Upload":
uploaded_files = st.sidebar.file_uploader(
"Upload multiple structure files",
type=["xyz", "cif", "POSCAR", "vasp", "CONTCAR", "mol", "sdf", "tmol", "extxyz"],
accept_multiple_files=True
)
# Clear state if no files present
if not uploaded_files:
st.session_state.atoms_list = []
st.session_state.atoms = None
else:
atoms_list = []
errors = []
for file in uploaded_files:
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[1]) as tmp:
tmp.write(file.getvalue())
tmp_path = tmp.name
atoms_obj = read(tmp_path)
atoms_obj.info["source_name"] = file.name
atoms_list.append(atoms_obj)
except Exception as e:
errors.append(f"{file.name}: {str(e)}")
finally:
if "tmp_path" in locals() and os.path.exists(tmp_path):
os.unlink(tmp_path)
# Store everything only if at least one success
if atoms_list:
st.session_state.atoms_list = atoms_list
st.session_state.atoms = atoms_list[0] # default: first item
st.sidebar.success(f"Loaded {len(atoms_list)} structures successfully!")
if len(atoms_list) > 1:
st.sidebar.info("You can now process them as a batch.")
if errors:
st.sidebar.error("Some files could not be loaded:\n" + "\n".join(errors))
# ----------------------------------------------------
# --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
# ----------------------------------------------------
# The structure kept in session state is the single source of truth for the rest of the app.
atoms = st.session_state.atoms
if atoms is not None:
    if not hasattr(atoms, "info"):
        atoms.info = {}
    # Only fill in values the structure does not already carry.
    # "spin" is used as the spin multiplicity (2S + 1) by the model widgets, so 1 means singlet.
    atoms.info.setdefault("charge", 0)
    atoms.info.setdefault("spin", 1)
    # Main-area confirmation of what is currently loaded.
    st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
st.sidebar.markdown("## Model Selection")
if mattersim_available:
    model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "SEVEN_NET", "MatterSim"])
else:
    model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "SEVEN_NET"])
selected_task_type = None  # Only set for FairChem UMA models.


def _prompt_charge_and_multiplicity(target_atoms):
    """Show sidebar inputs for total charge and spin multiplicity and store the choices.

    The values are written to ``target_atoms.info["charge"]`` and
    ``target_atoms.info["spin"]`` (spin multiplicity, 2S + 1).

    Defaults are read from ``target_atoms.info`` but coerced to ``int`` and clamped to
    the widget ranges, because ``st.number_input`` raises when ``value`` lies outside
    ``[min_value, max_value]`` (e.g. a stored spin of 0) or differs in type from the bounds.
    """
    stored_charge = target_atoms.info.get("charge")
    stored_spin = target_atoms.info.get("spin")
    default_charge = min(max(int(stored_charge), -10), 10) if stored_charge is not None else 0
    default_spin = min(max(int(stored_spin), 1), 11) if stored_spin is not None else 1
    charge = st.sidebar.number_input("Total Charge", min_value=-10, max_value=10, value=default_charge)
    spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=1, value=default_spin)
    target_atoms.info["charge"] = charge
    target_atoms.info["spin"] = spin_multiplicity  # Multiplicity, as the models expect


# Every branch below guards on `atoms is not None`: before a structure is loaded `atoms`
# is None, and touching `atoms.info` used to raise AttributeError.
if model_type == "MACE":
    selected_model = st.sidebar.selectbox("Select MACE Model:", list(MACE_MODELS.keys()))
    model_path = MACE_MODELS[selected_model]
    if selected_model == "MACE OMAT Medium":
        st.sidebar.warning("Using model under Academic Software License (ASL).")
    # Precision selector is disabled. NOTE: the MACE calculator is currently built with 'float32'.
    selected_default_dtype = 'float64'
    dispersion = st.sidebar.checkbox("Dispersion correction?", value=False)
    # The OMOL models take total charge and spin multiplicity.
    if selected_model in ("MACE OMOL-0 XL 4M", "MACE OMOL-0 XL 1024") and atoms is not None:
        _prompt_charge_and_multiplicity(atoms)
if model_type == "FairChem":
    selected_model = st.sidebar.selectbox("Select FairChem Model:", list(FAIRCHEM_MODELS.keys()))
    model_path = FAIRCHEM_MODELS[selected_model]
    if "UMA Small" in selected_model:
        st.sidebar.warning("Meta FAIR Acceptable Use Policy applies.")
        selected_task_type = st.sidebar.selectbox("Select UMA Model Task Type:", ["omol", "omat", "omc", "odac", "oc20"])
    if selected_task_type == "omol" and atoms is not None:
        _prompt_charge_and_multiplicity(atoms)
    elif atoms is not None:
        # Non-molecular tasks: reset to a neutral singlet.
        atoms.info["charge"] = 0
        atoms.info["spin"] = 1  # FairChem expects multiplicity
if model_type == "ORB":
    selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
    model_path = ORB_MODELS[selected_model]
    if "omat" in selected_model:
        st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
    selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest'])
if model_type == "MatterSim":
    selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
    model_path = MATTERSIM_MODELS[selected_model]
if model_type == "SEVEN_NET":
    selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
    if selected_model == '7net-mf-ompa':
        # Multi-fidelity model: the user must pick which fidelity ("modal") to evaluate.
        selected_modal_7net = st.sidebar.selectbox("Select Modal (multi fidelity model):", ['omat24', 'mpa'])
    model_path = SEVEN_NET_MODELS[selected_model]
# Stop this script run early when the structure exceeds the chosen model's atom limit.
if atoms is not None and not check_atom_limit(atoms, selected_model):
    st.stop()
# Pre-select the GPU option whenever CUDA is available.
device = st.sidebar.radio(
    "Computation Device:",
    ["CPU", "CUDA (GPU)"],
    index=1 if torch.cuda.is_available() else 0,
)
# A GPU request only takes effect when CUDA is actually available; otherwise use CPU.
device = "cuda" if (device == "CUDA (GPU)" and torch.cuda.is_available()) else "cpu"
if device == "cpu":
    if torch.cuda.is_available():
        st.sidebar.info("GPU is available but CPU was selected.")
    else:
        st.sidebar.info("No GPU detected. Using CPU.")
# --- TASK SELECTION ---
# The widgets below create task-specific module-level settings that are read later:
# max_steps / fmax / optimizer_type (local and cell optimization), global_method /
# temperature_K / global_steps / dr_amp / fmax_local (Global Optimization, currently
# commented out of the task list), and T (vibrational thermodynamics). Each exists only
# when its task is selected.
st.sidebar.markdown("## Task Selection")
task = st.sidebar.selectbox("Select Calculation Task:",
["Energy Calculation",
"Energy + Forces Calculation",
"Atomization/Cohesive Energy",
"Geometry Optimization",
"Cell + Geometry Optimization",
#"Global Optimization",
"Vibrational Mode Analysis",
#"Phonons"
])
if "Optimization" in task:
# st.sidebar.markdown("### Optimization Parameters")
# max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
# fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
# optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
st.sidebar.markdown("### Optimization Parameters")
# 1. Configuration for GLOBAL Optimization
if task == "Global Optimization":
global_method = st.sidebar.selectbox("Method:", ["Basin Hopping", "Minima Hopping"])
# Common parameters
temperature_K = st.sidebar.number_input("Temperature (K):", min_value=10.0, max_value=2000.0, value=300.0, step=10.0)
global_steps = st.sidebar.number_input("Search Steps:", min_value=10, max_value=500, value=50, step=10)
# Basin Hopping specific
if global_method == "Basin Hopping":
dr_amp = st.sidebar.number_input("Displacement Amplitude (Å):", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
# Minima Hopping specific
elif global_method == "Minima Hopping":
st.sidebar.caption("Minima Hopping automates threshold adjustments to escape local minima.")
fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
# 2. Configuration for LOCAL/CELL Optimization
else:
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
# Vibrational analysis: temperature used for the vibrational entropy.
if "Vibration" in task:
st.write("### Thermodynamic Quantities (Molecule Only)")
T = st.sidebar.number_input("Temperature (K)", value=298.15)
# --- MAIN PANEL: structure view (left column) and calculation summary (right column) ---
if atoms is not None:
col1, col2 = st.columns(2)
with col1:
st.markdown('### Structure Visualization', unsafe_allow_html=True)
# viz_style is reused later by the optimized-structure and trajectory viewers.
viz_style = st.selectbox("Select Visualization Style:",
["ball-stick",
"stick",
"ball"])
view_3d = get_structure_viz2(atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
st.components.v1.html(view_3d._make_html(), width=400, height=400)
st.markdown("### Structure Information")
# Cell parameters (a, b, c, alpha, beta, gamma) are shown only for periodic structures with a non-zero cell.
atoms_info = {
"Number of Atoms": len(atoms),
"Chemical Formula": atoms.get_chemical_formula(),
"Periodic Boundary Conditions (PBC)": atoms.pbc.tolist(),
"Cell Dimensions": np.round(atoms.cell.cellpar(),3).tolist() if atoms.pbc.any() and atoms.cell is not None and atoms.cell.any() else "No cell / Non-periodic",
"Atom Types": ", ".join(sorted(list(set(atoms.get_chemical_symbols()))))
}
for key, value in atoms_info.items():
st.write(f"**{key}:** {value}")
with col2:
st.markdown('## Calculation Setup', unsafe_allow_html=True)
st.markdown("### Selected Model")
st.write(f"**Model Type:** {model_type}")
st.write(f"**Model:** {selected_model}")
if model_type == "FairChem" and "UMA Small" in selected_model:
st.write(f"**UMA Task Type:** {selected_task_type}")
# `dispersion` is only defined when a MACE model is selected.
if model_type == "MACE":
st.write(f"**Dispersion:** {dispersion}")
st.write(f"**Device:** {device}")
st.markdown("### Selected Task")
st.write(f"**Task:** {task}")
# Matches both "Geometry Optimization" and "Cell + Geometry Optimization".
if "Geometry Optimization" in task:
st.write(f"**Max Steps:** {max_steps}")
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
st.write(f"**Optimizer:** {optimizer_type}")
run_calculation = st.button("Run Calculation", type="primary")
if run_calculation:
# Start each run from empty session state: drop cached lookups, loaded structures and
# trajectory frames from earlier runs. Iterate over a snapshot of the keys so deletion
# never happens during iteration over a live key view.
for key in list(st.session_state.keys()):
    del st.session_state[key]
# Per-run containers: `results` collects label -> value pairs printed after the task.
results = {}
#global table_placeholder # Ensure they are accessible
# NOTE(review): opt_log / table_placeholder are presumably consumed by streamlit_log (defined elsewhere) — confirm
opt_log = [] # Reset log for each run
if "Optimization" in task:
table_placeholder = st.empty() # Recreate placeholder for table
try:
# Every model runs with float32 as torch's default dtype (MACE is also built with 'float32' below).
torch.set_default_dtype(torch.float32)
with st.spinner("Running calculation... Please wait."):
# Work on a copy so the calculation never moves the loaded `atoms` object.
calc_atoms = atoms.copy()
if model_type == "MACE":
# st.write("Setting up MACE calculator...")
calc = get_mace_model(model_path, dispersion, device, 'float32')
elif model_type == "FairChem": # FairChem
# st.write("Setting up FairChem calculator...")
# Workaround for potential dtype issues when switching models
# if device == "cpu": # Ensure torch default dtype matches if needed
# torch.set_default_dtype(torch.float32)
# _ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
elif model_type == "ORB":
# st.write("Setting up ORB calculator...")
# orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
# ORB_MODELS values are model factory callables; calling one builds the force field.
orbff = model_path(device=device, precision=selected_default_dtype)
calc = ORBCalculator(orbff, device=device)
elif model_type == "MatterSim":
# st.write("Setting up MatterSim calculator...")
# NOTE: Running mattersim on windows requires changing source code file
# https://github.com/microsoft/mattersim/issues/112
# mattersim/datasets/utils/convertor.py: 117
# to pbc_ = np.array(structure.pbc, dtype=np.int64)
calc = MatterSimCalculator(load_path=model_path, device=device)
elif model_type == "SEVEN_NET":
# st.write("Setting up SEVENNET calculator...")
if model_path=='7net-mf-ompa':
calc = SevenNetCalculator(model=model_path, modal=selected_modal_7net, device=device)
else:
calc = SevenNetCalculator(model=model_path, device=device)
# Attach the chosen calculator, then dispatch on the selected task.
calc_atoms.calc = calc
if task == "Energy Calculation":
energy = calc_atoms.get_potential_energy()
results["Energy"] = f"{energy:.6f} eV"
elif task == "Energy + Forces Calculation":
energy = calc_atoms.get_potential_energy()
forces = calc_atoms.get_forces()
# Largest per-atom force norm; 0.0 for an empty structure.
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
results["Energy"] = f"{energy:.6f} eV"
results["Maximum Force"] = f"{max_force:.6f} eV/Å"
elif task == "Atomization/Cohesive Energy":
st.write("Calculating system energy...")
E_system = calc_atoms.get_potential_energy()
num_atoms = len(calc_atoms)
if num_atoms == 0:
st.error("Cannot calculate atomization/cohesive energy for a system with zero atoms.")
results["Error"] = "System has no atoms."
else:
atomic_numbers = calc_atoms.get_atomic_numbers()
E_isolated_atoms_total = 0.0
calculation_possible = True
if model_type == "FairChem":
st.write("Fetching FairChem reference energies for isolated atoms...")
ref_key_suffix = "_elem_refs"
chosen_ref_list_name = None
if "UMA Small" in selected_model:
if selected_task_type:
chosen_ref_list_name = selected_task_type + ref_key_suffix
elif "ESEN" in selected_model:
chosen_ref_list_name = "omol" + ref_key_suffix
if chosen_ref_list_name and chosen_ref_list_name in ELEMENT_REF_ENERGIES:
ref_energies = ELEMENT_REF_ENERGIES[chosen_ref_list_name]
missing_Z_refs = []
for Z_val in atomic_numbers:
if Z_val > 0 and Z_val < len(ref_energies):
E_isolated_atoms_total += ref_energies[Z_val]
else:
if Z_val not in missing_Z_refs: missing_Z_refs.append(Z_val)
if missing_Z_refs:
st.warning(f"Reference energy for atomic number(s) {sorted(list(set(missing_Z_refs)))} "
f"not found in '{chosen_ref_list_name}' list (max Z defined: {len(ref_energies)-1}). "
"These atoms are treated as having 0 reference energy.")
else:
st.error(f"Could not find or determine reference energy list for FairChem model: '{selected_model}' "
f"and UMA task type: '{selected_task_type}'. Cannot calculate atomization/cohesive energy.")
results["Error"] = "Missing FairChem reference energies."
calculation_possible = False
else:# == "MACE":
st.write("Calculating isolated atom energies with MACE...")
unique_atomic_numbers = sorted(list(set(atomic_numbers)))
atom_counts = {Z_unique: np.count_nonzero(atomic_numbers == Z_unique) for Z_unique in unique_atomic_numbers}
progress_text = "Calculating isolated atom energies: 0% complete"
mace_progress_bar = st.progress(0, text=progress_text)
for i, Z_unique in enumerate(unique_atomic_numbers):
isolated_atom = Atoms(numbers=[Z_unique], cell=[20, 20, 20], pbc=False)
if not hasattr(isolated_atom, 'info'): isolated_atom.info = {}
isolated_atom.info["charge"] = 0
isolated_atom.info["spin"] = 0
isolated_atom.calc = calc # Use the same MACE calculator
E_isolated_atom_type = isolated_atom.get_potential_energy()
E_isolated_atoms_total += E_isolated_atom_type * atom_counts[Z_unique]
progress_val = (i + 1) / len(unique_atomic_numbers)
mace_progress_bar.progress(progress_val, text=f"Calculating isolated atom energies for Z={Z_unique}: {int(progress_val*100)}% complete")
mace_progress_bar.empty()
if calculation_possible:
is_periodic = any(calc_atoms.pbc)
if is_periodic:
cohesive_E = (E_isolated_atoms_total - E_system) / num_atoms
results["Cohesive Energy"] = f"{cohesive_E:.6f} eV/atom"
else:
atomization_E = E_isolated_atoms_total - E_system
results["Atomization Energy"] = f"{atomization_E:.6f} eV"
results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
elif "Geometry Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
is_periodic = any(calc_atoms.pbc)
opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
# Create temporary trajectory file
traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
if optimizer_type == "BFGS":
opt = BFGS(opt_atoms_obj, trajectory=traj_filename)
elif optimizer_type == "LBFGS":
opt = LBFGS(opt_atoms_obj, trajectory=traj_filename)
else: # FIRE
opt = FIRE(opt_atoms_obj, trajectory=traj_filename)
# opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
opt.attach(lambda: streamlit_log(opt), interval=1)
st.write(f"Running {task.lower()}...")
opt.run(fmax=fmax, steps=max_steps)
energy = calc_atoms.get_potential_energy()
forces = calc_atoms.get_forces()
max_force = np.max(np.sqrt(np.sum(forces**2, axis=1))) if forces.shape[0] > 0 else 0.0
results["Final Energy"] = f"{energy:.6f} eV"
results["Final Maximum Force"] = f"{max_force:.6f} eV/Å"
results["Steps Taken"] = opt.get_number_of_steps()
results["Converged"] = "Yes" if opt.converged() else "No"
if task == "Cell + Geometry Optimization":
results["Final Cell Parameters"] = np.round(calc_atoms.cell.cellpar(), 4).tolist()
st.success("Calculation completed successfully!")
st.markdown("### Results")
for key, value in results.items():
    st.write(f"**{key}:** {value}")
if "Optimization" in task and "Final Energy" in results:  # Only after a local/cell optimization finished
    st.markdown("### Optimized Structure")
    # Show calc_atoms, the relaxed structure itself. For "Cell + Geometry Optimization",
    # opt_atoms_obj is a FrechetCellFilter wrapping calc_atoms rather than an Atoms
    # object, so it must not be handed to the structure viewer.
    opt_view = get_structure_viz2(calc_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
    st.components.v1.html(opt_view._make_html(), width=400, height=400)
    # Serialize in memory (extxyz keeps the cell for periodic systems). The old code
    # wrote by name into a NamedTemporaryFile while it was still open, which is not
    # portable (e.g. Windows).
    opt_buffer = io.StringIO()
    write(opt_buffer, calc_atoms, format="extxyz" if is_periodic else "xyz")
    xyz_content_opt = opt_buffer.getvalue()

    @st.fragment
    def show_optimized_structure_download_button():
        """Render the optimized-structure download button as an isolated fragment."""
        st.download_button(
            label="Download Optimized Structure (XYZ)",
            data=xyz_content_opt,
            file_name="optimized_structure.xyz",
            mime="chemical/x-xyz"
        )

    show_optimized_structure_download_button()

    @st.fragment
    def show_trajectory_and_controls():
        """Browse the optimization trajectory frame by frame and offer it as an XYZ download.

        Frames are read from traj_filename once and cached in st.session_state
        (traj_frames / traj_index); the navigation buttons only update the index.
        """
        if "traj_frames" not in st.session_state:
            if not os.path.exists(traj_filename):
                st.warning("Trajectory file not found.")
                return
            try:
                st.session_state.traj_frames = read(traj_filename, index=":")
                st.session_state.traj_index = 0
            except Exception as e:
                st.error(f"Error reading trajectory: {e}")
                return
        trajectory = st.session_state.traj_frames
        index = st.session_state.traj_index
        st.markdown("### Optimization Trajectory")
        st.write(f"Captured {len(trajectory)} optimization steps")
        # Navigation Buttons
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if st.button("⏮ First"):
                st.session_state.traj_index = 0
        with col2:
            if st.button("◀ Previous") and index > 0:
                st.session_state.traj_index -= 1
        with col3:
            if st.button("Next ▶") and index < len(trajectory) - 1:
                st.session_state.traj_index += 1
        with col4:
            if st.button("Last ⏭"):
                st.session_state.traj_index = len(trajectory) - 1
        # Show current frame
        current_atoms = trajectory[st.session_state.traj_index]
        st.write(f"Frame {st.session_state.traj_index + 1}/{len(trajectory)}")

        def atoms_to_xyz_string(frame, step_idx=None):
            """Format one frame as an XYZ block whose comment line carries the step and energy."""
            if step_idx is not None:
                comment = f"Step {step_idx}, Energy = {frame.get_potential_energy():.6f} eV"
            else:
                comment = f"Energy = {frame.get_potential_energy():.6f} eV"
            lines = [f"{len(frame)}", comment]
            lines.extend(
                f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}"
                for atom in frame
            )
            return "".join(line + "\n" for line in lines)

        traj_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
        st.components.v1.html(traj_view._make_html(), width=400, height=400)
        # Download button for entire trajectory
        trajectory_xyz = "".join(atoms_to_xyz_string(frame, i) for i, frame in enumerate(trajectory))
        st.download_button(
            label="Download Optimization Trajectory (XYZ)",
            data=trajectory_xyz,
            file_name="optimization_trajectory.xyz",
            mime="chemical/x-xyz"
        )

    show_trajectory_and_controls()
elif task == "Global Optimization":
st.info(f"Starting Global Optimization using {global_method}...")
# Create temporary trajectory file to store the "hopping" steps
traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
# Container for live updates
log_container = st.empty()
global_min_energy = 0
def global_log(opt_instance):
"""Helper to log global optimization steps."""
global global_min_energy
current_e = opt_instance.atoms.get_potential_energy()
# For BasinHopping, nsteps is available. For others, we might need a counter.
step = getattr(opt_instance, 'nsteps', 'N/A')
log_container.write(f"Global Step: {step} | Energy: {current_e:.6f} eV")
if current_e < global_min_energy:
global_min_energy = current_e
if global_method == "Basin Hopping":
# Basin Hopping requires Temperature in eV (kB * T)
kT = temperature_K * kB
# Create the wrapper for the hack needed to enforce the optimization to stop when it reaches a certain number of steps
class LimitedLBFGS(LBFGS):
def run(self, fmax=0.05, steps=None):
# 'steps' here overrides whatever BasinHopping tries to do.
# Set your desired max local steps (e.g., 200)
return super().run(fmax=fmax, steps=100)
# Initialize Basin Hopping with the trajectory file
bh = BasinHopping(calc_atoms,
temperature=kT,
dr=dr_amp,
optimizer=LimitedLBFGS,
fmax=fmax_local,
trajectory=traj_filename) # Log steps to file automatically
# Attach the live logger
bh.attach(lambda: global_log(bh), interval=1)
# Run the optimization
bh.run(global_steps)
results["Global Minimum Energy"] = f"{global_min_energy:.6f} eV"
results["Steps Taken"] = global_steps
results["Converged"] = "N/A (Global Search)"
elif global_method == "Minima Hopping":
# Minima Hopping manages its own internal optimizers and doesn't accept a 'trajectory'
# file argument in the same way BasinHopping does in __init__.
opt = MinimaHopping(calc_atoms,
T0=temperature_K,
fmax=fmax_local,
optimizer=LBFGS)
# We run it. Live logging is harder here without subclassing,
# so we rely on the final output for the trajectory.
opt(totalsteps=global_steps)
results["Current Energy"] = f"{calc_atoms.get_potential_energy():.6f} eV"
# Post-processing: MinimaHopping stores visited minima in an internal list usually.
# We explicitly write the found minima to the trajectory file so the visualizer below works.
# Note: opt.minima is a list of Atoms objects found during the hop.
if hasattr(opt, 'minima'):
from ase.io import write
write(traj_filename, opt.minima)
else:
# Fallback if specific version doesn't store list, just save final
write(traj_filename, calc_atoms)
st.success("Global Optimization Complete!")
st.markdown("### Results")
for key, value in results.items():
st.write(f"**{key}:** {value}")
# --- Visualization and Downloading (Fragmented) ---
# 1. Clean up the temp file path for reading
# We define the visualizer function using @st.fragment to prevent full re-runs
@st.fragment
def show_global_trajectory_and_dl():
from ase.io import read
import py3Dmol
# Helper to convert atoms list to XYZ string for the single download file
def atoms_list_to_xyz_string(atoms_list):
xyz_str = ""
for i, atoms in enumerate(atoms_list):
xyz_str += f"{len(atoms)}\n"
xyz_str += f"Step {i}, Energy = {atoms.get_potential_energy():.6f} eV\n"
for atom in atoms:
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
return xyz_str
if "global_traj_frames" not in st.session_state:
if os.path.exists(traj_filename):
try:
# Read the trajectory we just created
trajectory = read(traj_filename, index=":")
st.session_state.global_traj_frames = trajectory
st.session_state.global_traj_index = 0
except Exception as e:
st.error(f"Error reading trajectory: {e}")
return
else:
st.warning("Trajectory file not generated.")
return
trajectory = st.session_state.global_traj_frames
if not trajectory:
st.warning("No steps recorded in trajectory.")
return
index = st.session_state.global_traj_index
st.markdown("### Global Search Trajectory")
st.write(f"Captured {len(trajectory)} hopping steps (Local Minima)")
# Navigation Controls
col1, col2, col3, col4 = st.columns(4)
with col1:
if st.button("⏮ First", key="g_first"):
st.session_state.global_traj_index = 0
with col2:
if st.button("◀ Previous", key="g_prev") and index > 0:
st.session_state.global_traj_index -= 1
with col3:
if st.button("Next ▶", key="g_next") and index < len(trajectory) - 1:
st.session_state.global_traj_index += 1
with col4:
if st.button("Last ⏭", key="g_last"):
st.session_state.global_traj_index = len(trajectory) - 1
# Display Visualization
current_atoms = trajectory[st.session_state.global_traj_index]
st.write(f"Frame {st.session_state.global_traj_index + 1}/{len(trajectory)} | E = {current_atoms.get_potential_energy():.4f} eV")
viz_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
st.components.v1.html(viz_view._make_html(), width=400, height=400)
# Download Logic
full_xyz_content = atoms_list_to_xyz_string(trajectory)
st.download_button(
label="Download Trajectory (XYZ)",
data=full_xyz_content,
file_name="global_optimization_path.xyz",
mime="chemical/x-xyz"
)
# Separate Download for just the Best Structure (Last frame usually in BH, or sorted)
# Often in BH, the last frame is the accepted state, but not necessarily the global min seen *ever*.
# But usually, we want the lowest energy one.
energies = [a.get_potential_energy() for a in trajectory]
best_idx = np.argmin(energies)
best_atoms = trajectory[best_idx]
# Create XYZ for single best
with tempfile.NamedTemporaryFile(mode='w', suffix=".xyz", delete=False) as tmp_best:
write(tmp_best.name, best_atoms)
tmp_best_name = tmp_best.name
with open(tmp_best_name, "r") as f:
st.download_button(
label=f"Download Best Structure (E={energies[best_idx]:.4f} eV)",
data=f.read(),
file_name="best_global_structure.xyz",
mime="chemical/x-xyz"
)
os.unlink(tmp_best_name)
# Call the fragment function
show_global_trajectory_and_dl()
# Cleanup main trajectory file after loading it into session state if desired,
# though keeping it until session end is safer for re-reads.
# os.unlink(traj_filename)
elif task == "Vibrational Mode Analysis":
# Conversion factors
from ase.units import kB as kB_eVK, _Nav, J # ASE's constants
from scipy.constants import physical_constants
kB_JK = physical_constants["Boltzmann constant"][0] # J/K
is_periodic = any(calc_atoms.pbc)
st.write("Running vibrational mode analysis using finite differences...")
natoms = len(calc_atoms)
is_linear = False # Set manually or auto-detect
nmodes_expected = 3 * natoms - (5 if is_linear else 6)
# Create temporary directory to store .vib files
with tempfile.TemporaryDirectory() as tmpdir:
vib = Vibrations(calc_atoms, name=os.path.join(tmpdir, 'vib'))
with st.spinner("Calculating vibrational modes... This may take a few minutes."):
vib.run()
freqs = vib.get_frequencies()
energies = vib.get_energies()
print('\n\n\n\n\n\n\n\n')
# vib.get_hessian_2d()
# st.write(vib.summary())
# print('\n')
# vib.tabulate()
freqs_cm = freqs
freqs_eV = energies
# Classify frequencies
mode_data = []
for i, freq in enumerate(freqs_cm):
if freq < 0:
label = "Imaginary"
elif abs(freq) < 500:
label = "Low"
else:
label = "Physical"
mode_data.append({
"Mode": i + 1,
"Frequency (cm⁻¹)": round(freq, 2),
"Type": label
})
df_modes = pd.DataFrame(mode_data)
# Display summary and mode count
st.success("Vibrational analysis completed.")
st.write(f"Number of atoms: {natoms}")
st.write(f"Expected vibrational modes: {nmodes_expected}")
st.write(f"Found {len(freqs_cm)} modes (including translational/rotational modes).")
# Show table of modes
st.write("### Vibrational Mode Summary")
st.dataframe(df_modes, use_container_width=True)
# Store in results dictionary
results["Vibrational Modes"] = df_modes.to_dict(orient="records")
# Histogram plot of vibrational frequencies
st.write("### Frequency Distribution Histogram")
fig, ax = plt.subplots()
ax.hist(freqs_cm, bins=30, color='skyblue', edgecolor='black')
ax.set_xlabel("Frequency (cm⁻¹)")
ax.set_ylabel("Number of Modes")
ax.set_title("Distribution of Vibrational Frequencies")
st.pyplot(fig)
# CSV download
csv_buffer = io.StringIO()
df_modes.to_csv(csv_buffer, index=False)
st.download_button(
label="Download Vibrational Frequencies (CSV)",
data=csv_buffer.getvalue(),
file_name="vibrational_modes.csv",
mime="text/csv"
)
# -------- Thermodynamic Analysis for Molecules --------
if not is_periodic:
# Filter physical frequencies > 1 cm⁻¹ (to avoid numerical issues)
physical_freqs_eV = np.array([f for f in freqs_eV if f > 1e-5])
# Zero-point vibrational energy (ZPE)
ZPE = 0.5 * np.sum(physical_freqs_eV) # in eV
# Vibrational entropy (in eV/K)
vib_entropy = 0.0
for f in physical_freqs_eV:
x = f / (kB_eVK * T)
vib_entropy += (x / (np.exp(x) - 1) - np.log(1 - np.exp(-x)))
S_vib_eVK = kB_eVK * vib_entropy # eV/K
S_vib_JmolK = S_vib_eVK * J * _Nav # J/mol·K
results["ZPE (eV)"] = ZPE.real
results["Vibrational Entropy (eV/K)"] = S_vib_eVK
results["Vibrational Entropy (J/mol·K)"] = S_vib_JmolK
st.write(f"**Zero-point vibrational energy (ZPE)**: {ZPE.real:.6f} eV")
st.write(f"**Vibrational entropy**: {S_vib_eVK:.6f} eV/K")
else:
st.info("Thermodynamic properties like ZPE and entropy are currently only meaningful for isolated molecules (non-periodic systems).")
# Phonon band structure and DOS via finite displacements in a supercell.
# NOTE: "Phonons" is currently commented out of the task selector, so this branch is not reachable from the UI.
elif task == "Phonons":
from ase.phonons import Phonons
st.write("### Phonon Band Structure and Density of States")
is_periodic = any(calc_atoms.pbc)
if not is_periodic:
st.error("Phonon calculations require a periodic structure. Please use a periodic system.")
else:
with tempfile.TemporaryDirectory() as tmpdir:
st.info("Running phonon calculation using finite displacements...")
# NOTE(review): a fixed 7x7x7 supercell can be very expensive for larger unit cells.
sc = (7, 7, 7)
# Create phonon object
ph = Phonons(calc_atoms, calc_atoms.calc, supercell=sc, delta=0.001, name=os.path.join(tmpdir, 'phonon'))
with st.spinner("Displacing atoms and computing forces..."):
ph.run()
# Build dynamical matrix
ph.read(acoustic=True)
ph.clean()
# Band path and DOS
# NOTE(review): the fixed special-point path 'GXKGL' presumably assumes a lattice (e.g. FCC) that defines these labels — confirm
# path = calc_atoms.cell.bandpath('GXULGK', npoints=100)
path = calc_atoms.cell.bandpath('GXKGL', npoints=100)
# path = calc_atoms.cell.bandpath(eps=0.00001)
bs = ph.get_band_structure(path)
dos = ph.get_dos(kpts=(20, 20, 20)).sample_grid(npts=100, width=1e-3)
# Plotting: band structure on the left axes, DOS on a narrow axes to its right
fig = plt.figure(figsize=(7, 4))
ax = fig.add_axes([0.12, 0.07, 0.67, 0.85])
# Upper y-limit shared by the band and DOS plots
emax = 0.075
bs.plot(ax=ax, emin=0.0, emax=emax)
dosax = fig.add_axes([0.8, 0.07, 0.17, 0.85])
dosax.fill_between(
dos.get_weights(),
dos.get_energies(),
y2=0,
color='grey',
edgecolor='k',
lw=1,
)
dosax.set_ylim(0, emax)
dosax.set_yticks([])
dosax.set_xticks([])
dosax.set_xlabel('DOS', fontsize=14)
st.pyplot(fig)
st.success("Phonon band structure and DOS successfully plotted.")
except Exception as e:
st.error(f"🔴 Calculation error: {str(e)}")
st.error("Please check the structure, model compatibility, and parameters. For FairChem UMA, ensure the task type (omol, omat etc.) is appropriate for your system (e.g. omol for molecules, omat for materials).")
import traceback
st.error(f"Traceback: {traceback.format_exc()}")
else:
st.info("👋 Welcome! Please select or upload a structure using the sidebar options to begin.")
st.markdown("---")
# "About" panel. Backslashes in the LaTeX below are doubled (\\text, \\sum) so
# the non-raw string contains no invalid escape sequences (Python 3.12+ warns).
with st.expander('ℹ️ About This App & Foundational MLIPs'):
    st.write("""
**Test, compare, and benchmark universal machine learning interatomic potentials (MLIPs).**
This application allows you to perform atomistic simulations using pre-trained foundational MLIPs
from the MACE, FairChem (by Meta AI), SevenNet, ORB, and MatterSim libraries.
**Features:**
- Upload structure files (XYZ, CIF, POSCAR, etc.) or use built-in examples.
- Select from various MACE, FairChem, SevenNet, ORB, and MatterSim models.
- Calculate energies, forces, and perform geometry/cell optimizations.
- **New**: Calculate atomization energy (for molecules) or cohesive energy (for periodic systems).
- Visualize atomic structures in 3D and download results.
**Quick Start:**
1. **Input**: Choose an input method in the sidebar (e.g., "Select Example").
2. **Model**: Pick a model type (e.g., MACE, FairChem, SevenNet, ORB) and specific model. For FairChem UMA, select the appropriate task type (e.g., `omol` for molecules, `omat` for materials).
3. **Task**: Select a calculation task (e.g., "Energy Calculation", "Atomization/Cohesive Energy", "Geometry Optimization").
4. **Run**: Click "Run Calculation" and view the results.
**Atomization/Cohesive Energy Notes:**
- **Atomization Energy** ($E_{\\text{atomization}} = \\sum E_{\\text{isolated atoms}} - E_{\\text{molecule}}$) is typically for non-periodic systems (molecules).
- **Cohesive Energy** ($E_{\\text{cohesive}} = (\\sum E_{\\text{isolated atoms}} - E_{\\text{bulk system}}) / N_{\\text{atoms}}$) is for periodic systems.
- For **MACE models**, isolated atom energies are computed on-the-fly.
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
""")
with st.expander('🔧 Tech Stack & System Information'):
    import importlib.metadata as importlib_metadata
    import platform
    import psutil
    st.markdown("### System Information")
    col1, col2 = st.columns(2)
    with col1:
        st.write("**Operating System:**")
        st.write(f"- OS: {platform.system()} {platform.release()}")
        st.write(f"- Version: {platform.version()}")
        st.write(f"- Architecture: {platform.machine()}")
        st.write(f"- Processor: {platform.processor()}")
        st.write("\n**Python Environment:**")
        st.write(f"- Python Version: {platform.python_version()}")
        st.write(f"- Python Implementation: {platform.python_implementation()}")
    with col2:
        st.write("**Hardware Resources:**")
        # cpu_count(logical=False) returns None when the count can't be determined.
        physical_cores = psutil.cpu_count(logical=False) or "N/A"
        st.write(f"- CPU Cores: {physical_cores} physical, {psutil.cpu_count(logical=True)} logical")
        # cpu_percent blocks for `interval` seconds. Streamlit re-executes this
        # script on every interaction (an expander body runs even when it is
        # collapsed), so a 1 s window would stall every rerun.
        st.write(f"- CPU Usage: {psutil.cpu_percent(interval=0.1)}%")
        memory = psutil.virtual_memory()
        st.write(f"- Total RAM: {memory.total / (1024**3):.2f} GB")
        st.write(f"- Available RAM: {memory.available / (1024**3):.2f} GB")
        st.write(f"- RAM Usage: {memory.percent}%")
        disk = psutil.disk_usage('/')
        st.write(f"- Total Disk Space: {disk.total / (1024**3):.2f} GB")
        st.write(f"- Free Disk Space: {disk.free / (1024**3):.2f} GB")
        st.write(f"- Disk Usage: {disk.percent}%")
    st.markdown("### Package Versions")
    # Distribution names as installed by pip, not import names. PyYAML is
    # imported as `yaml`, but a lookup of "yaml" always reports it missing.
    packages_to_check = [
        'streamlit', 'torch', 'numpy', 'ase', 'py3Dmol',
        'mace-torch', 'fairchem-core', 'orb-models', 'sevenn',
        'pandas', 'matplotlib', 'scipy', 'PyYAML', 'huggingface-hub'
    ]
    if mattersim_available:
        packages_to_check.append('mattersim')
    # importlib.metadata (stdlib) replaces the deprecated pkg_resources API.
    package_versions = {}
    for package in packages_to_check:
        try:
            package_versions[package] = importlib_metadata.version(package)
        except importlib_metadata.PackageNotFoundError:
            package_versions[package] = "Not installed"
    # Display in two columns
    col1, col2 = st.columns(2)
    items = list(package_versions.items())
    mid_point = len(items) // 2
    with col1:
        for package, version in items[:mid_point]:
            st.write(f"**{package}:** {version}")
    with col2:
        for package, version in items[mid_point:]:
            st.write(f"**{package}:** {version}")
    # PyTorch specific information
    st.markdown("### PyTorch Configuration")
    st.write(f"**PyTorch Version:** {torch.__version__}")
    st.write(f"**CUDA Available:** {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        st.write(f"**CUDA Version:** {torch.version.cuda}")
        st.write(f"**cuDNN Version:** {torch.backends.cudnn.version()}")
        st.write(f"**Number of GPUs:** {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            st.write(f"**GPU {i}:** {torch.cuda.get_device_name(i)}")
    else:
        st.write("Running on CPU only")
# Page footer: credits and authorship.
st.markdown("---")
st.markdown("Universal MLIP Playground App | Created with Streamlit, ASE, MACE, FairChem, SevenNet, ORB and ❤️")
st.markdown("Developed by [Manas Sharma](https://manas.bragitoff.com/) in the groups of [Prof. Ananth Govind Rajan](https://www.agrgroup.org/) and [Prof. Sudeep Punnathanam](https://chemeng.iisc.ac.in/sudeep/) at [IISc Bangalore](https://iisc.ac.in/)")