Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +12 -0
src/streamlit_app.py
CHANGED
|
@@ -14,6 +14,7 @@ from mace.calculators import mace_mp
|
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
| 15 |
from orb_models.forcefield import pretrained
|
| 16 |
from orb_models.forcefield.calculator import ORBCalculator
|
|
|
|
| 17 |
import pandas as pd
|
| 18 |
import yaml # Added for FairChem reference energies
|
| 19 |
|
|
@@ -686,6 +687,11 @@ ORB_MODELS = {
|
|
| 686 |
"V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
|
| 687 |
"V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
|
| 688 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 689 |
@st.cache_resource
|
| 690 |
def get_mace_model(model_path, device, selected_default_dtype):
|
| 691 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
@@ -780,6 +786,9 @@ if model_type == "ORB":
|
|
| 780 |
# if "omat" in selected_model:
|
| 781 |
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
| 782 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
|
|
|
|
|
|
|
|
|
| 783 |
if atoms is not None:
|
| 784 |
if not check_atom_limit(atoms, selected_model):
|
| 785 |
st.stop() # Stop execution if limit exceeded
|
|
@@ -870,6 +879,9 @@ if atoms is not None:
|
|
| 870 |
st.write("Setting up ORB calculator...")
|
| 871 |
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
| 872 |
calc = ORBCalculator(orbff, device=device)
|
|
|
|
|
|
|
|
|
|
| 873 |
calc_atoms.calc = calc
|
| 874 |
|
| 875 |
if task == "Energy Calculation":
|
|
|
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
| 15 |
from orb_models.forcefield import pretrained
|
| 16 |
from orb_models.forcefield.calculator import ORBCalculator
|
| 17 |
+
from mattersim.forcefield import MatterSimCalculator
|
| 18 |
import pandas as pd
|
| 19 |
import yaml # Added for FairChem reference energies
|
| 20 |
|
|
|
|
| 687 |
"V3 MPA Direct (inf)": "orb-v3-direct-inf-mpa",
|
| 688 |
"V3 MPA Direct (20)": "orb-v3-direct-20-mpa",
|
| 689 |
}
|
| 690 |
+
# Define the available MatterSim models
|
| 691 |
+
MATTERSIM_MODELS = {
|
| 692 |
+
"V1 SMALL: MatterSim-v1.0.0-1M.pth",
|
| 693 |
+
"V1 LARGE: MatterSim-v1.0.0-5M.pth"
|
| 694 |
+
}
|
| 695 |
@st.cache_resource
|
| 696 |
def get_mace_model(model_path, device, selected_default_dtype):
|
| 697 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
|
|
| 786 |
# if "omat" in selected_model:
|
| 787 |
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
| 788 |
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
| 789 |
+
if model_type == "MatterSim":
|
| 790 |
+
selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
|
| 791 |
+
model_path = MATTERSIM_MODELS[selected_model]
|
| 792 |
if atoms is not None:
|
| 793 |
if not check_atom_limit(atoms, selected_model):
|
| 794 |
st.stop() # Stop execution if limit exceeded
|
|
|
|
| 879 |
st.write("Setting up ORB calculator...")
|
| 880 |
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
| 881 |
calc = ORBCalculator(orbff, device=device)
|
| 882 |
+
elif model_type == "MatterSim":
|
| 883 |
+
st.write("Setting up MatterSim calculator...")
|
| 884 |
+
calc = MatterSimCalculator(load_path=model_path, device=device)
|
| 885 |
calc_atoms.calc = calc
|
| 886 |
|
| 887 |
if task == "Energy Calculation":
|