Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +16 -3
src/streamlit_app.py
CHANGED
|
@@ -14,6 +14,7 @@ from mace.calculators import mace_mp
|
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
| 15 |
from orb_models.forcefield import pretrained
|
| 16 |
from orb_models.forcefield.calculator import ORBCalculator
|
|
|
|
| 17 |
import pandas as pd
|
| 18 |
import yaml # Added for FairChem reference energies
|
| 19 |
import subprocess
|
|
@@ -718,8 +719,14 @@ ORB_MODELS = {
|
|
| 718 |
}
|
| 719 |
# Define the available MatterSim models
|
| 720 |
MATTERSIM_MODELS = {
|
| 721 |
-
"V1 SMALL: MatterSim-v1.0.0-1M.pth",
|
| 722 |
-
"V1 LARGE: MatterSim-v1.0.0-5M.pth"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 723 |
}
|
| 724 |
@st.cache_resource
|
| 725 |
def get_mace_model(model_path, device, selected_default_dtype):
|
|
@@ -792,7 +799,7 @@ st.sidebar.markdown("## Model Selection")
|
|
| 792 |
if mattersim_available:
|
| 793 |
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "MatterSim"])
|
| 794 |
else:
|
| 795 |
-
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB"])
|
| 796 |
|
| 797 |
selected_task_type = None # For FairChem UMA
|
| 798 |
if model_type == "MACE":
|
|
@@ -821,6 +828,9 @@ if model_type == "ORB":
|
|
| 821 |
if model_type == "MatterSim":
|
| 822 |
selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
|
| 823 |
model_path = MATTERSIM_MODELS[selected_model]
|
|
|
|
|
|
|
|
|
|
| 824 |
if atoms is not None:
|
| 825 |
if not check_atom_limit(atoms, selected_model):
|
| 826 |
st.stop() # Stop execution if limit exceeded
|
|
@@ -914,6 +924,9 @@ if atoms is not None:
|
|
| 914 |
elif model_type == "MatterSim":
|
| 915 |
st.write("Setting up MatterSim calculator...")
|
| 916 |
calc = MatterSimCalculator(load_path=model_path, device=device)
|
|
|
|
|
|
|
|
|
|
| 917 |
calc_atoms.calc = calc
|
| 918 |
|
| 919 |
if task == "Energy Calculation":
|
|
|
|
| 14 |
from fairchem.core import pretrained_mlip, FAIRChemCalculator
|
| 15 |
from orb_models.forcefield import pretrained
|
| 16 |
from orb_models.forcefield.calculator import ORBCalculator
|
| 17 |
+
from sevenn.calculator import SevenNetCalculator
|
| 18 |
import pandas as pd
|
| 19 |
import yaml # Added for FairChem reference energies
|
| 20 |
import subprocess
|
|
|
|
| 719 |
}
|
| 720 |
# Define the available MatterSim models
|
| 721 |
MATTERSIM_MODELS = {
|
| 722 |
+
"V1 SMALL": "MatterSim-v1.0.0-1M.pth",
|
| 723 |
+
"V1 LARGE": "MatterSim-v1.0.0-5M.pth"
|
| 724 |
+
}
|
| 725 |
+
SEVEN_NET_MODELS = {
|
| 726 |
+
"7net-0": "7net-0",
|
| 727 |
+
"7net-l3i5": "7net-l3i5",
|
| 728 |
+
"7net-omat": "7net-omat",
|
| 729 |
+
"7net-mf-ompa": "7net-mf-ompa"
|
| 730 |
}
|
| 731 |
@st.cache_resource
|
| 732 |
def get_mace_model(model_path, device, selected_default_dtype):
|
|
|
|
| 799 |
if mattersim_available:
|
| 800 |
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "MatterSim"])
|
| 801 |
else:
|
| 802 |
+
model_type = st.sidebar.radio("Select Model Type:", ["MACE", "FairChem", "ORB", "SEVEN_NET"])
|
| 803 |
|
| 804 |
selected_task_type = None # For FairChem UMA
|
| 805 |
if model_type == "MACE":
|
|
|
|
| 828 |
if model_type == "MatterSim":
|
| 829 |
selected_model = st.sidebar.selectbox("Select MatterSim Model:", list(MATTERSIM_MODELS.keys()))
|
| 830 |
model_path = MATTERSIM_MODELS[selected_model]
|
| 831 |
+
if model_type == "SEVEN_NET":
|
| 832 |
+
selected_model = st.sidebar.selectbox("Select SEVENNET Model:", list(SEVEN_NET_MODELS.keys()))
|
| 833 |
+
model_path = SEVEN_NET_MODELS[selected_model]
|
| 834 |
if atoms is not None:
|
| 835 |
if not check_atom_limit(atoms, selected_model):
|
| 836 |
st.stop() # Stop execution if limit exceeded
|
|
|
|
| 924 |
elif model_type == "MatterSim":
|
| 925 |
st.write("Setting up MatterSim calculator...")
|
| 926 |
calc = MatterSimCalculator(load_path=model_path, device=device)
|
| 927 |
+
elif model_type == "SEVEN_NET":
|
| 928 |
+
st.write("Setting up SEVENNET calculator...")
|
| 929 |
+
calc = SevenNetD3Calculator(model=model_path, device=device)
|
| 930 |
calc_atoms.calc = calc
|
| 931 |
|
| 932 |
if task == "Energy Calculation":
|