Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +18 -5
src/streamlit_app.py
CHANGED
|
@@ -12,6 +12,8 @@ from ase.visualize import view
|
|
| 12 |
import py3Dmol
|
| 13 |
from mace.calculators import mace_mp
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
|
|
|
|
|
|
| 15 |
import pandas as pd
|
| 16 |
import yaml # Added for FairChem reference energies
|
| 17 |
|
|
@@ -673,7 +675,10 @@ FAIRCHEM_MODELS = {
|
|
| 673 |
"ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
|
| 674 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
| 675 |
}
|
| 676 |
-
|
|
|
|
|
|
|
|
|
|
| 677 |
@st.cache_resource
|
| 678 |
def get_mace_model(model_path, device, selected_default_dtype):
|
| 679 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
@@ -742,7 +747,7 @@ if atoms is not None:
|
|
| 742 |
|
| 743 |
|
| 744 |
st.sidebar.markdown("## Model Selection")
|
| 745 |
-
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem"])
|
| 746 |
|
| 747 |
selected_task_type = None # For FairChem UMA
|
| 748 |
if model_type == "MACE":
|
|
@@ -762,7 +767,12 @@ if model_type == "FairChem":
|
|
| 762 |
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
|
| 763 |
atoms.info["charge"] = charge
|
| 764 |
atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
if atoms is not None:
|
| 767 |
if not check_atom_limit(atoms, selected_model):
|
| 768 |
st.stop() # Stop execution if limit exceeded
|
|
@@ -842,14 +852,17 @@ if atoms is not None:
|
|
| 842 |
if model_type == "MACE":
|
| 843 |
# st.write("Setting up MACE calculator...")
|
| 844 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
| 845 |
-
|
| 846 |
# st.write("Setting up FairChem calculator...")
|
| 847 |
# Workaround for potential dtype issues when switching models
|
| 848 |
if device == "cpu": # Ensure torch default dtype matches if needed
|
| 849 |
torch.set_default_dtype(torch.float32)
|
| 850 |
_ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
|
| 851 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
| 852 |
-
|
|
|
|
|
|
|
|
|
|
| 853 |
calc_atoms.calc = calc
|
| 854 |
|
| 855 |
if task == "Energy Calculation":
|
|
|
|
| 12 |
import py3Dmol
|
| 13 |
from mace.calculators import mace_mp
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
| 15 |
+
from orb_models.forcefield import pretrained
|
| 16 |
+
from orb_models.forcefield.calculator import ORBCalculator
|
| 17 |
import pandas as pd
|
| 18 |
import yaml # Added for FairChem reference energies
|
| 19 |
|
|
|
|
| 675 |
"ESEN SM Conserving All OMOL": "esen-sm-conserving-all-omol",
|
| 676 |
"ESEN SM Direct All OMOL": "esen-sm-direct-all-omol"
|
| 677 |
}
|
| 678 |
+
# Define the available ORB models
|
| 679 |
+
ORB_MODELS = {
|
| 680 |
+
"V3 OMAT Conserving": "orb_v3_conservative_inf_omat",
|
| 681 |
+
}
|
| 682 |
@st.cache_resource
|
| 683 |
def get_mace_model(model_path, device, selected_default_dtype):
|
| 684 |
return mace_mp(model=model_path, device=device, default_dtype=selected_default_dtype)
|
|
|
|
| 747 |
|
| 748 |
|
| 749 |
st.sidebar.markdown("## Model Selection")
|
| 750 |
+
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB"])
|
| 751 |
|
| 752 |
selected_task_type = None # For FairChem UMA
|
| 753 |
if model_type == "MACE":
|
|
|
|
| 767 |
spin_multiplicity = st.sidebar.number_input("Spin Multiplicity (2S + 1)", min_value=1, max_value=11, step=2, value=int(atoms.info.get("spin",0)*2+1 if atoms.info.get("spin",0) is not None else 1)) # Assuming spin in atoms.info is S
|
| 768 |
atoms.info["charge"] = charge
|
| 769 |
atoms.info["spin"] = spin_multiplicity # FairChem expects multiplicity
|
| 770 |
+
if model_type == "ORB":
|
| 771 |
+
selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
|
| 772 |
+
model_path = ORB_MODELS[selected_model]
|
| 773 |
+
# if "omat" in selected_model:
|
| 774 |
+
# st.sidebar.warning("Using model under Academic Software License (ASL) license, see [https://github.com/gabor1/ASL](https://github.com/gabor1/ASL). To use this model you accept the terms of the license.")
|
| 775 |
+
selected_default_dtype = st.sidebar.selectbox("Select Precision (default_dtype):", ['float32-high', 'float32-highest', 'float64'])
|
| 776 |
if atoms is not None:
|
| 777 |
if not check_atom_limit(atoms, selected_model):
|
| 778 |
st.stop() # Stop execution if limit exceeded
|
|
|
|
| 852 |
if model_type == "MACE":
|
| 853 |
# st.write("Setting up MACE calculator...")
|
| 854 |
calc = get_mace_model(model_path, device, selected_default_dtype)
|
| 855 |
+
elif model_type == "FairChem": # FairChem
|
| 856 |
# st.write("Setting up FairChem calculator...")
|
| 857 |
# Workaround for potential dtype issues when switching models
|
| 858 |
if device == "cpu": # Ensure torch default dtype matches if needed
|
| 859 |
torch.set_default_dtype(torch.float32)
|
| 860 |
_ = get_mace_model(MACE_MODELS["MACE MP 0a Small"], 'cpu', 'float32') # Dummy call
|
| 861 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
| 862 |
+
elif model_type == "ORB":
|
| 863 |
+
st.write("Setting up ORB calculator...")
|
| 864 |
+
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
| 865 |
+
calc = ORBCalculator(orbff, device=device)
|
| 866 |
calc_atoms.calc = calc
|
| 867 |
|
| 868 |
if task == "Energy Calculation":
|