Spaces:
Running
Running
add orbmol model and fix issues with orb models
Browse files- src/streamlit_app.py +150 -120
src/streamlit_app.py
CHANGED
|
@@ -30,8 +30,11 @@ import sys
|
|
| 30 |
import pkg_resources
|
| 31 |
from ase.vibrations import Vibrations
|
| 32 |
from mp_api.client import MPRester
|
|
|
|
|
|
|
| 33 |
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
| 34 |
from pymatgen.io.ase import AseAtomsAdaptor
|
|
|
|
| 35 |
import matplotlib.pyplot as plt
|
| 36 |
mattersim_available = True
|
| 37 |
if mattersim_available:
|
|
@@ -926,14 +929,16 @@ FAIRCHEM_MODELS = {
|
|
| 926 |
}
|
| 927 |
# Define the available ORB models
|
| 928 |
ORB_MODELS = {
|
| 929 |
-
"V3
|
| 930 |
-
"V3
|
| 931 |
-
"V3 OMAT
|
| 932 |
-
"V3 OMAT
|
| 933 |
-
"V3
|
| 934 |
-
"V3
|
| 935 |
-
"V3 MPA
|
| 936 |
-
"V3 MPA
|
|
|
|
|
|
|
| 937 |
}
|
| 938 |
# Define the available MatterSim models
|
| 939 |
MATTERSIM_MODELS = {
|
|
@@ -959,117 +964,11 @@ def get_fairchem_model(selected_model_name, model_path_or_name, device, selected
|
|
| 959 |
calc = FAIRChemCalculator(predictor, task_name="omol")
|
| 960 |
return calc
|
| 961 |
|
| 962 |
-
# st.sidebar.markdown("## Input Options")
|
| 963 |
-
# input_method = st.sidebar.radio("Choose Input Method:",
|
| 964 |
-
# ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
|
| 965 |
-
# atoms = None
|
| 966 |
-
|
| 967 |
-
# if input_method == "Upload File":
|
| 968 |
-
# uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
|
| 969 |
-
# if uploaded_file:
|
| 970 |
-
# with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
| 971 |
-
# tmp_file.write(uploaded_file.getvalue())
|
| 972 |
-
# tmp_filepath = tmp_file.name
|
| 973 |
-
# try:
|
| 974 |
-
# atoms = read(tmp_filepath)
|
| 975 |
-
# st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
|
| 976 |
-
# except Exception as e:
|
| 977 |
-
# st.sidebar.error(f"Error loading file: {str(e)}")
|
| 978 |
-
# finally:
|
| 979 |
-
# if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 980 |
-
# os.unlink(tmp_filepath)
|
| 981 |
-
|
| 982 |
-
# elif input_method == "Select Example":
|
| 983 |
-
# example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
| 984 |
-
# if example_name:
|
| 985 |
-
# file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
| 986 |
-
# try:
|
| 987 |
-
# atoms = read(file_path)
|
| 988 |
-
# st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
|
| 989 |
-
# except Exception as e:
|
| 990 |
-
# st.sidebar.error(f"Error loading example: {str(e)}")
|
| 991 |
-
|
| 992 |
-
# elif input_method == "Paste Content":
|
| 993 |
-
# file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
| 994 |
-
# content = st.sidebar.text_area("Paste file content here:", height=200)
|
| 995 |
-
# if content:
|
| 996 |
-
# try:
|
| 997 |
-
# suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
|
| 998 |
-
# suffix = suffix_map.get(file_format, ".xyz")
|
| 999 |
-
# with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
| 1000 |
-
# tmp_file.write(content.encode())
|
| 1001 |
-
# tmp_filepath = tmp_file.name
|
| 1002 |
-
# atoms = read(tmp_filepath)
|
| 1003 |
-
# st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
| 1004 |
-
# except Exception as e:
|
| 1005 |
-
# st.sidebar.error(f"Error parsing content: {str(e)}")
|
| 1006 |
-
# finally:
|
| 1007 |
-
# if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 1008 |
-
# os.unlink(tmp_filepath)
|
| 1009 |
-
# elif input_method == "Materials Project ID":
|
| 1010 |
-
# # st.sidebar.info("Requires an API Key from [Materials Project](https://next-gen.materialsproject.org/api)")
|
| 1011 |
-
|
| 1012 |
-
# # 1. Get API Key (You can also set this as an env variable)
|
| 1013 |
-
# # mp_api_key = st.sidebar.text_input("MP API Key", type="password")
|
| 1014 |
-
# mp_api_key = os.getenv("MP_API_KEY")
|
| 1015 |
-
# # 2. Get Material ID
|
| 1016 |
-
# material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149")
|
| 1017 |
-
|
| 1018 |
-
# # 3. Choose Cell Type
|
| 1019 |
-
# cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'])
|
| 1020 |
-
|
| 1021 |
-
# # 1. Initialize the session state variable if it doesn't exist
|
| 1022 |
-
# if "atoms" not in st.session_state:
|
| 1023 |
-
# st.session_state.atoms = None
|
| 1024 |
-
|
| 1025 |
-
# if st.sidebar.button("Fetch Structure"):
|
| 1026 |
-
# if not mp_api_key:
|
| 1027 |
-
# st.sidebar.error("Please enter your Materials Project API Key.")
|
| 1028 |
-
# elif not material_id:
|
| 1029 |
-
# st.sidebar.error("Please enter a Material ID.")
|
| 1030 |
-
# else:
|
| 1031 |
-
# try:
|
| 1032 |
-
# with st.spinner(f"Fetching {material_id}..."):
|
| 1033 |
-
# # Connect to MP
|
| 1034 |
-
# with MPRester(mp_api_key) as mpr:
|
| 1035 |
-
# # Get Pymatgen Structure
|
| 1036 |
-
# pmg_structure = mpr.get_structure_by_material_id(material_id)
|
| 1037 |
-
|
| 1038 |
-
# # Handle Symmetry (Primitive vs Conventional)
|
| 1039 |
-
# analyzer = SpacegroupAnalyzer(pmg_structure)
|
| 1040 |
-
|
| 1041 |
-
# if cell_type == 'Conventional Unit Cell':
|
| 1042 |
-
# final_structure = analyzer.get_conventional_standard_structure()
|
| 1043 |
-
# else:
|
| 1044 |
-
# final_structure = analyzer.get_primitive_standard_structure()
|
| 1045 |
-
|
| 1046 |
-
# # Convert Pymatgen Structure -> ASE Atoms
|
| 1047 |
-
# # 2. Store the result in session_state instead of a local variable
|
| 1048 |
-
# st.session_state.atoms = AseAtomsAdaptor.get_atoms(final_structure)
|
| 1049 |
-
|
| 1050 |
-
# st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
|
| 1051 |
-
|
| 1052 |
-
# except Exception as e:
|
| 1053 |
-
# st.sidebar.error(f"Error fetching data: {str(e)}")
|
| 1054 |
-
# st.session_state.atoms = None
|
| 1055 |
-
|
| 1056 |
-
# # 3. Retrieve the data from session_state for use in the rest of your app
|
| 1057 |
-
# atoms = st.session_state.atoms
|
| 1058 |
-
|
| 1059 |
-
# # --- Rest of your processing code ---
|
| 1060 |
-
# # if atoms is not None:
|
| 1061 |
-
# # st.write(f"System is ready. Formula: {atoms.get_chemical_formula()}")
|
| 1062 |
-
# # else:
|
| 1063 |
-
# # st.info("Please fetch a structure to begin.")
|
| 1064 |
-
# if atoms is not None:
|
| 1065 |
-
# if not hasattr(atoms, 'info'):
|
| 1066 |
-
# atoms.info = {}
|
| 1067 |
-
# atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
|
| 1068 |
-
# atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1069 |
-
|
| 1070 |
# --- INITIALIZATION (Must be run first) ---
|
| 1071 |
if "atoms" not in st.session_state:
|
| 1072 |
st.session_state.atoms = None
|
|
|
|
|
|
|
| 1073 |
|
| 1074 |
# Reset atoms state if input method changes, to prevent using old data
|
| 1075 |
# Use a key to track the currently active input method
|
|
@@ -1078,7 +977,7 @@ if 'current_input_method' not in st.session_state:
|
|
| 1078 |
|
| 1079 |
st.sidebar.markdown("## Input Options")
|
| 1080 |
input_method = st.sidebar.radio("Choose Input Method:",
|
| 1081 |
-
["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
|
| 1082 |
|
| 1083 |
# If the input method changes, clear the loaded structure
|
| 1084 |
if input_method != st.session_state.current_input_method:
|
|
@@ -1168,6 +1067,92 @@ elif input_method == "Paste Content":
|
|
| 1168 |
# Clear structure if text area is empty
|
| 1169 |
st.session_state.atoms = None
|
| 1170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1171 |
|
| 1172 |
# --- MATERIALS PROJECT ID ---
|
| 1173 |
elif input_method == "Materials Project ID":
|
|
@@ -1210,6 +1195,52 @@ elif input_method == "Materials Project ID":
|
|
| 1210 |
elif not material_id:
|
| 1211 |
st.sidebar.error("Please enter a Material ID.")
|
| 1212 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1213 |
# ----------------------------------------------------
|
| 1214 |
# --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
|
| 1215 |
# ----------------------------------------------------
|
|
@@ -1224,8 +1255,6 @@ if atoms is not None:
|
|
| 1224 |
|
| 1225 |
# Display confirmation in the main area (optional, helps the user confirm what's loaded)
|
| 1226 |
st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
|
| 1227 |
-
|
| 1228 |
-
# You can now add your processing buttons here, which will consistently use the 'atoms' variable.
|
| 1229 |
|
| 1230 |
st.sidebar.markdown("## Model Selection")
|
| 1231 |
if mattersim_available:
|
|
@@ -1409,7 +1438,8 @@ if atoms is not None:
|
|
| 1409 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
| 1410 |
elif model_type == "ORB":
|
| 1411 |
# st.write("Setting up ORB calculator...")
|
| 1412 |
-
orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
|
|
|
| 1413 |
calc = ORBCalculator(orbff, device=device)
|
| 1414 |
elif model_type == "MatterSim":
|
| 1415 |
# st.write("Setting up MatterSim calculator...")
|
|
|
|
| 30 |
import pkg_resources
|
| 31 |
from ase.vibrations import Vibrations
|
| 32 |
from mp_api.client import MPRester
|
| 33 |
+
import pubchempy as pcp
|
| 34 |
+
from io import StringIO
|
| 35 |
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
| 36 |
from pymatgen.io.ase import AseAtomsAdaptor
|
| 37 |
+
from pymatgen.core.structure import Molecule
|
| 38 |
import matplotlib.pyplot as plt
|
| 39 |
mattersim_available = True
|
| 40 |
if mattersim_available:
|
|
|
|
| 929 |
}
|
| 930 |
# Define the available ORB models
|
| 931 |
ORB_MODELS = {
|
| 932 |
+
"V3 OMOL Conservative": pretrained.orb_v3_conservative_omol,
|
| 933 |
+
"V3 OMOL Direct": pretrained.orb_v3_direct_omol,
|
| 934 |
+
"V3 OMAT Conservative (inf)": pretrained.orb_v3_conservative_inf_omat,
|
| 935 |
+
"V3 OMAT Conservative (20)": pretrained.orb_v3_conservative_20_omat,
|
| 936 |
+
"V3 OMAT Direct (inf)": pretrained.orb_v3_direct_inf_omat,
|
| 937 |
+
"V3 OMAT Direct (20)": pretrained.orb_v3_direct_20_omat,
|
| 938 |
+
"V3 MPA Conservative (inf)": pretrained.orb_v3_conservative_inf_mpa,
|
| 939 |
+
"V3 MPA Conservative (20)": pretrained.orb_v3_conservative_20_mpa,
|
| 940 |
+
"V3 MPA Direct (inf)": pretrained.orb_v3_direct_inf_mpa,
|
| 941 |
+
"V3 MPA Direct (20)": pretrained.orb_v3_direct_20_mpa,
|
| 942 |
}
|
| 943 |
# Define the available MatterSim models
|
| 944 |
MATTERSIM_MODELS = {
|
|
|
|
| 964 |
calc = FAIRChemCalculator(predictor, task_name="omol")
|
| 965 |
return calc
|
| 966 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
# --- INITIALIZATION (Must be run first) ---
|
| 968 |
if "atoms" not in st.session_state:
|
| 969 |
st.session_state.atoms = None
|
| 970 |
+
if "atoms_list" not in st.session_state:
|
| 971 |
+
st.session_state.atoms_list = []
|
| 972 |
|
| 973 |
# Reset atoms state if input method changes, to prevent using old data
|
| 974 |
# Use a key to track the currently active input method
|
|
|
|
| 977 |
|
| 978 |
st.sidebar.markdown("## Input Options")
|
| 979 |
input_method = st.sidebar.radio("Choose Input Method:",
|
| 980 |
+
["Select Example", "Upload File", "Paste Content", "Materials Project ID", "PubChem", "Batch Upload"])
|
| 981 |
|
| 982 |
# If the input method changes, clear the loaded structure
|
| 983 |
if input_method != st.session_state.current_input_method:
|
|
|
|
| 1067 |
# Clear structure if text area is empty
|
| 1068 |
st.session_state.atoms = None
|
| 1069 |
|
| 1070 |
+
# --- PUBCHEM SEARCH MODE ---
|
| 1071 |
+
elif input_method == "PubChem":
|
| 1072 |
+
|
| 1073 |
+
|
| 1074 |
+
st.sidebar.markdown("### Search PubChem")
|
| 1075 |
+
|
| 1076 |
+
query = st.sidebar.text_input("Enter name or formula (e.g., H2O, water, methane):",
|
| 1077 |
+
key="pubchem_query", value="water")
|
| 1078 |
+
|
| 1079 |
+
# Reset atoms if no query
|
| 1080 |
+
if query.strip() == "":
|
| 1081 |
+
st.session_state.atoms = None
|
| 1082 |
+
|
| 1083 |
+
# Step 1: Search PubChem
|
| 1084 |
+
if query and query.strip():
|
| 1085 |
+
# Avoid re-searching if query is unchanged
|
| 1086 |
+
if "pubchem_last_query" not in st.session_state or st.session_state.pubchem_last_query != query:
|
| 1087 |
+
try:
|
| 1088 |
+
with st.spinner("Searching PubChem..."):
|
| 1089 |
+
results = pcp.get_compounds(query, "name") # name OR formula works
|
| 1090 |
+
st.session_state.pubchem_results = results
|
| 1091 |
+
st.session_state.pubchem_last_query = query
|
| 1092 |
+
except Exception as e:
|
| 1093 |
+
st.sidebar.error(f"Error searching PubChem: {str(e)}")
|
| 1094 |
+
st.session_state.pubchem_results = None
|
| 1095 |
+
|
| 1096 |
+
results = st.session_state.get("pubchem_results", [])
|
| 1097 |
+
if results:
|
| 1098 |
+
# Convert to displayable table
|
| 1099 |
+
df = pd.DataFrame(
|
| 1100 |
+
[(c.cid, c.iupac_name, c.molecular_formula, c.molecular_weight, c.isomeric_smiles)
|
| 1101 |
+
for c in results],
|
| 1102 |
+
columns=["CID", "Name", "Formula", "Weight", "SMILES"]
|
| 1103 |
+
)
|
| 1104 |
+
st.sidebar.success(f"Found {len(df)} result(s).")
|
| 1105 |
+
st.sidebar.dataframe(df)
|
| 1106 |
+
|
| 1107 |
+
# Choose a CID
|
| 1108 |
+
cid = st.sidebar.selectbox("Select CID", df["CID"], key="pubchem_cid")
|
| 1109 |
+
|
| 1110 |
+
# Step 2: Retrieve 3D structure for selected CID
|
| 1111 |
+
if cid:
|
| 1112 |
+
if "pubchem_last_cid" not in st.session_state or st.session_state.pubchem_last_cid != cid:
|
| 1113 |
+
try:
|
| 1114 |
+
with st.spinner("Fetching 3D coordinates..."):
|
| 1115 |
+
# Function to format floating-point numbers with alignment
|
| 1116 |
+
def format_number(num, width=10, precision=5):
|
| 1117 |
+
# Handles positive/negative numbers while maintaining alignment
|
| 1118 |
+
return f"{num: {width}.{precision}f}"
|
| 1119 |
+
# CID to XYZ
|
| 1120 |
+
def generate_xyz_coordinates(cid):
|
| 1121 |
+
compound = pcp.Compound.from_cid(cid, record_type='3d')
|
| 1122 |
+
atoms = compound.atoms
|
| 1123 |
+
coords = [(atom.x, atom.y, atom.z) for atom in atoms]
|
| 1124 |
+
|
| 1125 |
+
num_atoms = len(atoms)
|
| 1126 |
+
xyz_text = f"{num_atoms}\n{compound.cid}\n"
|
| 1127 |
+
|
| 1128 |
+
for atom, coord in zip(atoms, coords):
|
| 1129 |
+
atom_symbol = atom.element
|
| 1130 |
+
x, y, z = coord
|
| 1131 |
+
xyz_text += f"{atom_symbol} {format_number(x, precision=8)} {format_number(y, precision=8)} {format_number(z, precision=8)}\n"
|
| 1132 |
+
|
| 1133 |
+
return xyz_text
|
| 1134 |
+
def get_molecule(cid):
|
| 1135 |
+
xyz_str = generate_xyz_coordinates(cid)
|
| 1136 |
+
return Molecule.from_str(xyz_str, fmt='xyz'), xyz_str
|
| 1137 |
+
# Fetch SDF with 3D conformer
|
| 1138 |
+
# sdf_str = pcp.Compound.from_cid(int(cid)).to_sdf()
|
| 1139 |
+
selected_molecule, xyz_str = get_molecule(cid)
|
| 1140 |
+
|
| 1141 |
+
# Convert SDF → ASE Atoms using temporary memory buffer
|
| 1142 |
+
atoms_to_store = read(StringIO(xyz_str), format="xyz")
|
| 1143 |
+
|
| 1144 |
+
atoms_to_store.info["source_name"] = f"PubChem CID {cid}"
|
| 1145 |
+
st.session_state.atoms = atoms_to_store
|
| 1146 |
+
st.session_state.pubchem_last_cid = cid
|
| 1147 |
+
|
| 1148 |
+
st.sidebar.success(f"Loaded PubChem structure with {len(atoms_to_store)} atoms!")
|
| 1149 |
+
|
| 1150 |
+
except Exception as e:
|
| 1151 |
+
st.sidebar.error(f"Unable to retrieve 3D structure: {str(e)}")
|
| 1152 |
+
st.session_state.atoms = None
|
| 1153 |
+
st.session_state.pubchem_last_cid = None
|
| 1154 |
+
else:
|
| 1155 |
+
st.sidebar.info("No PubChem results found.")
|
| 1156 |
|
| 1157 |
# --- MATERIALS PROJECT ID ---
|
| 1158 |
elif input_method == "Materials Project ID":
|
|
|
|
| 1195 |
elif not material_id:
|
| 1196 |
st.sidebar.error("Please enter a Material ID.")
|
| 1197 |
|
| 1198 |
+
# --- BATCH UPLOAD MULTIPLE FILES ---
|
| 1199 |
+
elif input_method == "Batch Upload":
|
| 1200 |
+
|
| 1201 |
+
uploaded_files = st.sidebar.file_uploader(
|
| 1202 |
+
"Upload multiple structure files",
|
| 1203 |
+
type=["xyz", "cif", "POSCAR", "vasp", "CONTCAR", "mol", "sdf", "tmol", "extxyz"],
|
| 1204 |
+
accept_multiple_files=True
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
# Clear state if no files present
|
| 1208 |
+
if not uploaded_files:
|
| 1209 |
+
st.session_state.atoms_list = []
|
| 1210 |
+
st.session_state.atoms = None
|
| 1211 |
+
|
| 1212 |
+
else:
|
| 1213 |
+
atoms_list = []
|
| 1214 |
+
errors = []
|
| 1215 |
+
|
| 1216 |
+
for file in uploaded_files:
|
| 1217 |
+
try:
|
| 1218 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.name)[1]) as tmp:
|
| 1219 |
+
tmp.write(file.getvalue())
|
| 1220 |
+
tmp_path = tmp.name
|
| 1221 |
+
|
| 1222 |
+
atoms_obj = read(tmp_path)
|
| 1223 |
+
atoms_obj.info["source_name"] = file.name
|
| 1224 |
+
atoms_list.append(atoms_obj)
|
| 1225 |
+
|
| 1226 |
+
except Exception as e:
|
| 1227 |
+
errors.append(f"{file.name}: {str(e)}")
|
| 1228 |
+
|
| 1229 |
+
finally:
|
| 1230 |
+
if "tmp_path" in locals() and os.path.exists(tmp_path):
|
| 1231 |
+
os.unlink(tmp_path)
|
| 1232 |
+
|
| 1233 |
+
# Store everything only if at least one success
|
| 1234 |
+
if atoms_list:
|
| 1235 |
+
st.session_state.atoms_list = atoms_list
|
| 1236 |
+
st.session_state.atoms = atoms_list[0] # default: first item
|
| 1237 |
+
st.sidebar.success(f"Loaded {len(atoms_list)} structures successfully!")
|
| 1238 |
+
|
| 1239 |
+
if len(atoms_list) > 1:
|
| 1240 |
+
st.sidebar.info("You can now process them as a batch.")
|
| 1241 |
+
|
| 1242 |
+
if errors:
|
| 1243 |
+
st.sidebar.error("Some files could not be loaded:\n" + "\n".join(errors))
|
| 1244 |
# ----------------------------------------------------
|
| 1245 |
# --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
|
| 1246 |
# ----------------------------------------------------
|
|
|
|
| 1255 |
|
| 1256 |
# Display confirmation in the main area (optional, helps the user confirm what's loaded)
|
| 1257 |
st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
|
|
|
|
|
|
|
| 1258 |
|
| 1259 |
st.sidebar.markdown("## Model Selection")
|
| 1260 |
if mattersim_available:
|
|
|
|
| 1438 |
calc = get_fairchem_model(selected_model, model_path, device, selected_task_type)
|
| 1439 |
elif model_type == "ORB":
|
| 1440 |
# st.write("Setting up ORB calculator...")
|
| 1441 |
+
# orbff = pretrained.orb_v3_conservative_inf_omat(device=device, precision=selected_default_dtype)
|
| 1442 |
+
orbff = model_path(device=device, precision=selected_default_dtype)
|
| 1443 |
calc = ORBCalculator(orbff, device=device)
|
| 1444 |
elif model_type == "MatterSim":
|
| 1445 |
# st.write("Setting up MatterSim calculator...")
|