Spaces:
Running
Running
add materials project ID support
Browse files- src/streamlit_app.py +461 -39
src/streamlit_app.py
CHANGED
|
@@ -4,13 +4,16 @@ import io
|
|
| 4 |
import tempfile
|
| 5 |
import torch
|
| 6 |
# FOR CPU only mode
|
| 7 |
-
#
|
| 8 |
# Or disable compilation entirely
|
| 9 |
# torch.backends.cudnn.enabled = False
|
| 10 |
import numpy as np
|
| 11 |
from ase import Atoms
|
| 12 |
from ase.io import read, write
|
| 13 |
from ase.optimize import BFGS, LBFGS, FIRE
|
|
|
|
|
|
|
|
|
|
| 14 |
from ase.constraints import FixAtoms
|
| 15 |
from ase.filters import FrechetCellFilter
|
| 16 |
from ase.visualize import view
|
|
@@ -26,6 +29,9 @@ import subprocess
|
|
| 26 |
import sys
|
| 27 |
import pkg_resources
|
| 28 |
from ase.vibrations import Vibrations
|
|
|
|
|
|
|
|
|
|
| 29 |
import matplotlib.pyplot as plt
|
| 30 |
mattersim_available = True
|
| 31 |
if mattersim_available:
|
|
@@ -894,6 +900,7 @@ def check_atom_limit(atoms_obj, selected_model):
|
|
| 894 |
MACE_MODELS = {
|
| 895 |
"MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
|
| 896 |
"MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
|
|
|
|
| 897 |
"MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
|
| 898 |
"MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
|
| 899 |
"MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
|
|
@@ -952,59 +959,273 @@ def get_fairchem_model(selected_model_name, model_path_or_name, device, selected
|
|
| 952 |
calc = FAIRChemCalculator(predictor, task_name="omol")
|
| 953 |
return calc
|
| 954 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
st.sidebar.markdown("## Input Options")
|
| 956 |
-
input_method = st.sidebar.radio("Choose Input Method:",
|
| 957 |
-
|
| 958 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
if input_method == "Upload File":
|
| 960 |
uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
|
|
|
|
|
|
|
| 961 |
if uploaded_file:
|
| 962 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
| 963 |
-
tmp_file.write(uploaded_file.getvalue())
|
| 964 |
-
tmp_filepath = tmp_file.name
|
| 965 |
try:
|
| 966 |
-
|
| 967 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
except Exception as e:
|
| 969 |
st.sidebar.error(f"Error loading file: {str(e)}")
|
|
|
|
|
|
|
| 970 |
finally:
|
|
|
|
| 971 |
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 972 |
os.unlink(tmp_filepath)
|
|
|
|
|
|
|
|
|
|
| 973 |
|
|
|
|
| 974 |
elif input_method == "Select Example":
|
|
|
|
| 975 |
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
| 976 |
-
|
|
|
|
|
|
|
| 977 |
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
| 978 |
try:
|
| 979 |
-
|
| 980 |
-
|
|
|
|
|
|
|
| 981 |
except Exception as e:
|
| 982 |
st.sidebar.error(f"Error loading example: {str(e)}")
|
|
|
|
| 983 |
|
|
|
|
| 984 |
elif input_method == "Paste Content":
|
| 985 |
file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
| 986 |
-
content = st.sidebar.text_area("Paste file content here:", height=200)
|
|
|
|
|
|
|
|
|
|
| 987 |
if content:
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
|
| 1002 |
if atoms is not None:
|
| 1003 |
if not hasattr(atoms, 'info'):
|
| 1004 |
atoms.info = {}
|
| 1005 |
atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
|
| 1006 |
atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1007 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1008 |
|
| 1009 |
st.sidebar.markdown("## Model Selection")
|
| 1010 |
if mattersim_available:
|
|
@@ -1044,7 +1265,6 @@ if model_type == "FairChem":
|
|
| 1044 |
else:
|
| 1045 |
atoms.info["charge"] = 0
|
| 1046 |
atoms.info["spin"] = 1 # FairChem expects multiplicity
|
| 1047 |
-
|
| 1048 |
if model_type == "ORB":
|
| 1049 |
selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
|
| 1050 |
model_path = ORB_MODELS[selected_model]
|
|
@@ -1076,19 +1296,43 @@ st.sidebar.markdown("## Task Selection")
|
|
| 1076 |
task = st.sidebar.selectbox("Select Calculation Task:",
|
| 1077 |
["Energy Calculation",
|
| 1078 |
"Energy + Forces Calculation",
|
| 1079 |
-
"Atomization/Cohesive Energy",
|
| 1080 |
"Geometry Optimization",
|
| 1081 |
"Cell + Geometry Optimization",
|
|
|
|
| 1082 |
"Vibrational Mode Analysis",
|
| 1083 |
#"Phonons"
|
| 1084 |
-
|
| 1085 |
|
| 1086 |
if "Optimization" in task:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1087 |
st.sidebar.markdown("### Optimization Parameters")
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1091 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1092 |
|
| 1093 |
if "Vibration" in task:
|
| 1094 |
st.write("### Thermodynamic Quantities (Molecule Only)")
|
|
@@ -1131,7 +1375,7 @@ if atoms is not None:
|
|
| 1131 |
st.markdown("### Selected Task")
|
| 1132 |
st.write(f"**Task:** {task}")
|
| 1133 |
|
| 1134 |
-
if "Optimization" in task:
|
| 1135 |
st.write(f"**Max Steps:** {max_steps}")
|
| 1136 |
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
|
| 1137 |
st.write(f"**Optimizer:** {optimizer_type}")
|
|
@@ -1270,7 +1514,7 @@ if atoms is not None:
|
|
| 1270 |
results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
|
| 1271 |
results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
|
| 1272 |
|
| 1273 |
-
elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
|
| 1274 |
is_periodic = any(calc_atoms.pbc)
|
| 1275 |
opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
|
| 1276 |
# Create temporary trajectory file
|
|
@@ -1402,6 +1646,188 @@ if atoms is not None:
|
|
| 1402 |
)
|
| 1403 |
|
| 1404 |
show_trajectory_and_controls()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1405 |
elif task == "Vibrational Mode Analysis":
|
| 1406 |
# Conversion factors
|
| 1407 |
from ase.units import kB as kB_eVK, _Nav, J # ASE's constants
|
|
@@ -1422,20 +1848,15 @@ if atoms is not None:
|
|
| 1422 |
vib.run()
|
| 1423 |
freqs = vib.get_frequencies()
|
| 1424 |
energies = vib.get_energies()
|
| 1425 |
-
|
| 1426 |
|
| 1427 |
-
|
| 1428 |
print('\n\n\n\n\n\n\n\n')
|
| 1429 |
# vib.get_hessian_2d()
|
| 1430 |
# st.write(vib.summary())
|
| 1431 |
# print('\n')
|
| 1432 |
# vib.tabulate()
|
| 1433 |
-
|
| 1434 |
|
| 1435 |
-
|
| 1436 |
freqs_cm = freqs
|
| 1437 |
freqs_eV = energies
|
| 1438 |
-
|
| 1439 |
|
| 1440 |
# Classify frequencies
|
| 1441 |
mode_data = []
|
|
@@ -1601,6 +2022,7 @@ with st.expander('ℹ️ About This App & Foundational MLIPs'):
|
|
| 1601 |
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
|
| 1602 |
""")
|
| 1603 |
|
|
|
|
| 1604 |
with st.expander('🔧 Tech Stack & System Information'):
|
| 1605 |
import platform
|
| 1606 |
import psutil
|
|
|
|
| 4 |
import tempfile
|
| 5 |
import torch
|
| 6 |
# FOR CPU only mode
|
| 7 |
+
#torch._dynamo.config.suppress_errors = True
|
| 8 |
# Or disable compilation entirely
|
| 9 |
# torch.backends.cudnn.enabled = False
|
| 10 |
import numpy as np
|
| 11 |
from ase import Atoms
|
| 12 |
from ase.io import read, write
|
| 13 |
from ase.optimize import BFGS, LBFGS, FIRE
|
| 14 |
+
from ase.optimize.basin import BasinHopping
|
| 15 |
+
from ase.optimize.minimahopping import MinimaHopping
|
| 16 |
+
from ase.units import kB
|
| 17 |
from ase.constraints import FixAtoms
|
| 18 |
from ase.filters import FrechetCellFilter
|
| 19 |
from ase.visualize import view
|
|
|
|
| 29 |
import sys
|
| 30 |
import pkg_resources
|
| 31 |
from ase.vibrations import Vibrations
|
| 32 |
+
from mp_api.client import MPRester
|
| 33 |
+
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
| 34 |
+
from pymatgen.io.ase import AseAtomsAdaptor
|
| 35 |
import matplotlib.pyplot as plt
|
| 36 |
mattersim_available = True
|
| 37 |
if mattersim_available:
|
|
|
|
| 900 |
MACE_MODELS = {
|
| 901 |
"MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
|
| 902 |
"MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
|
| 903 |
+
"MACE OMAT Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-small.model",
|
| 904 |
"MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
|
| 905 |
"MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
|
| 906 |
"MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
|
|
|
|
| 959 |
calc = FAIRChemCalculator(predictor, task_name="omol")
|
| 960 |
return calc
|
| 961 |
|
| 962 |
+
# st.sidebar.markdown("## Input Options")
|
| 963 |
+
# input_method = st.sidebar.radio("Choose Input Method:",
|
| 964 |
+
# ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
|
| 965 |
+
# atoms = None
|
| 966 |
+
|
| 967 |
+
# if input_method == "Upload File":
|
| 968 |
+
# uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
|
| 969 |
+
# if uploaded_file:
|
| 970 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
| 971 |
+
# tmp_file.write(uploaded_file.getvalue())
|
| 972 |
+
# tmp_filepath = tmp_file.name
|
| 973 |
+
# try:
|
| 974 |
+
# atoms = read(tmp_filepath)
|
| 975 |
+
# st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
|
| 976 |
+
# except Exception as e:
|
| 977 |
+
# st.sidebar.error(f"Error loading file: {str(e)}")
|
| 978 |
+
# finally:
|
| 979 |
+
# if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 980 |
+
# os.unlink(tmp_filepath)
|
| 981 |
+
|
| 982 |
+
# elif input_method == "Select Example":
|
| 983 |
+
# example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
| 984 |
+
# if example_name:
|
| 985 |
+
# file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
| 986 |
+
# try:
|
| 987 |
+
# atoms = read(file_path)
|
| 988 |
+
# st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
|
| 989 |
+
# except Exception as e:
|
| 990 |
+
# st.sidebar.error(f"Error loading example: {str(e)}")
|
| 991 |
+
|
| 992 |
+
# elif input_method == "Paste Content":
|
| 993 |
+
# file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
| 994 |
+
# content = st.sidebar.text_area("Paste file content here:", height=200)
|
| 995 |
+
# if content:
|
| 996 |
+
# try:
|
| 997 |
+
# suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
|
| 998 |
+
# suffix = suffix_map.get(file_format, ".xyz")
|
| 999 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
| 1000 |
+
# tmp_file.write(content.encode())
|
| 1001 |
+
# tmp_filepath = tmp_file.name
|
| 1002 |
+
# atoms = read(tmp_filepath)
|
| 1003 |
+
# st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
|
| 1004 |
+
# except Exception as e:
|
| 1005 |
+
# st.sidebar.error(f"Error parsing content: {str(e)}")
|
| 1006 |
+
# finally:
|
| 1007 |
+
# if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 1008 |
+
# os.unlink(tmp_filepath)
|
| 1009 |
+
# elif input_method == "Materials Project ID":
|
| 1010 |
+
# # st.sidebar.info("Requires an API Key from [Materials Project](https://next-gen.materialsproject.org/api)")
|
| 1011 |
+
|
| 1012 |
+
# # 1. Get API Key (You can also set this as an env variable)
|
| 1013 |
+
# # mp_api_key = st.sidebar.text_input("MP API Key", type="password")
|
| 1014 |
+
# mp_api_key = os.getenv("MP_API_KEY")
|
| 1015 |
+
# # 2. Get Material ID
|
| 1016 |
+
# material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149")
|
| 1017 |
+
|
| 1018 |
+
# # 3. Choose Cell Type
|
| 1019 |
+
# cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'])
|
| 1020 |
+
|
| 1021 |
+
# # 1. Initialize the session state variable if it doesn't exist
|
| 1022 |
+
# if "atoms" not in st.session_state:
|
| 1023 |
+
# st.session_state.atoms = None
|
| 1024 |
+
|
| 1025 |
+
# if st.sidebar.button("Fetch Structure"):
|
| 1026 |
+
# if not mp_api_key:
|
| 1027 |
+
# st.sidebar.error("Please enter your Materials Project API Key.")
|
| 1028 |
+
# elif not material_id:
|
| 1029 |
+
# st.sidebar.error("Please enter a Material ID.")
|
| 1030 |
+
# else:
|
| 1031 |
+
# try:
|
| 1032 |
+
# with st.spinner(f"Fetching {material_id}..."):
|
| 1033 |
+
# # Connect to MP
|
| 1034 |
+
# with MPRester(mp_api_key) as mpr:
|
| 1035 |
+
# # Get Pymatgen Structure
|
| 1036 |
+
# pmg_structure = mpr.get_structure_by_material_id(material_id)
|
| 1037 |
+
|
| 1038 |
+
# # Handle Symmetry (Primitive vs Conventional)
|
| 1039 |
+
# analyzer = SpacegroupAnalyzer(pmg_structure)
|
| 1040 |
+
|
| 1041 |
+
# if cell_type == 'Conventional Unit Cell':
|
| 1042 |
+
# final_structure = analyzer.get_conventional_standard_structure()
|
| 1043 |
+
# else:
|
| 1044 |
+
# final_structure = analyzer.get_primitive_standard_structure()
|
| 1045 |
+
|
| 1046 |
+
# # Convert Pymatgen Structure -> ASE Atoms
|
| 1047 |
+
# # 2. Store the result in session_state instead of a local variable
|
| 1048 |
+
# st.session_state.atoms = AseAtomsAdaptor.get_atoms(final_structure)
|
| 1049 |
+
|
| 1050 |
+
# st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
|
| 1051 |
+
|
| 1052 |
+
# except Exception as e:
|
| 1053 |
+
# st.sidebar.error(f"Error fetching data: {str(e)}")
|
| 1054 |
+
# st.session_state.atoms = None
|
| 1055 |
+
|
| 1056 |
+
# # 3. Retrieve the data from session_state for use in the rest of your app
|
| 1057 |
+
# atoms = st.session_state.atoms
|
| 1058 |
+
|
| 1059 |
+
# # --- Rest of your processing code ---
|
| 1060 |
+
# # if atoms is not None:
|
| 1061 |
+
# # st.write(f"System is ready. Formula: {atoms.get_chemical_formula()}")
|
| 1062 |
+
# # else:
|
| 1063 |
+
# # st.info("Please fetch a structure to begin.")
|
| 1064 |
+
# if atoms is not None:
|
| 1065 |
+
# if not hasattr(atoms, 'info'):
|
| 1066 |
+
# atoms.info = {}
|
| 1067 |
+
# atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
|
| 1068 |
+
# atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1069 |
+
|
| 1070 |
+
# --- INITIALIZATION (Must be run first) ---
|
| 1071 |
+
if "atoms" not in st.session_state:
|
| 1072 |
+
st.session_state.atoms = None
|
| 1073 |
+
|
| 1074 |
+
# Reset atoms state if input method changes, to prevent using old data
|
| 1075 |
+
# Use a key to track the currently active input method
|
| 1076 |
+
if 'current_input_method' not in st.session_state:
|
| 1077 |
+
st.session_state.current_input_method = "Select Example"
|
| 1078 |
+
|
| 1079 |
st.sidebar.markdown("## Input Options")
|
| 1080 |
+
input_method = st.sidebar.radio("Choose Input Method:",
|
| 1081 |
+
["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
|
| 1082 |
|
| 1083 |
+
# If the input method changes, clear the loaded structure
|
| 1084 |
+
if input_method != st.session_state.current_input_method:
|
| 1085 |
+
st.session_state.atoms = None
|
| 1086 |
+
st.session_state.current_input_method = input_method
|
| 1087 |
+
|
| 1088 |
+
# --- UPLOAD FILE ---
|
| 1089 |
if input_method == "Upload File":
|
| 1090 |
uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
|
| 1091 |
+
|
| 1092 |
+
# Load immediately upon file upload/change (no button needed)
|
| 1093 |
if uploaded_file:
|
|
|
|
|
|
|
|
|
|
| 1094 |
try:
|
| 1095 |
+
# Check if this file content has already been loaded to prevent redundant temp file operations
|
| 1096 |
+
if 'uploaded_file_hash' not in st.session_state or st.session_state.uploaded_file_hash != uploaded_file.name:
|
| 1097 |
+
|
| 1098 |
+
# Use tempfile to handle the uploaded file content
|
| 1099 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
|
| 1100 |
+
tmp_file.write(uploaded_file.getvalue())
|
| 1101 |
+
tmp_filepath = tmp_file.name
|
| 1102 |
+
|
| 1103 |
+
atoms_to_store = read(tmp_filepath)
|
| 1104 |
+
st.session_state.atoms = atoms_to_store
|
| 1105 |
+
st.session_state.uploaded_file_hash = uploaded_file.name # Track the loaded file
|
| 1106 |
+
st.sidebar.success(f"Successfully loaded structure with {len(atoms_to_store)} atoms!")
|
| 1107 |
+
|
| 1108 |
except Exception as e:
|
| 1109 |
st.sidebar.error(f"Error loading file: {str(e)}")
|
| 1110 |
+
st.session_state.atoms = None
|
| 1111 |
+
st.session_state.uploaded_file_hash = None # Clear hash on failure
|
| 1112 |
finally:
|
| 1113 |
+
# Clean up the temporary file
|
| 1114 |
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 1115 |
os.unlink(tmp_filepath)
|
| 1116 |
+
else:
|
| 1117 |
+
# Clear structure if file uploader is empty
|
| 1118 |
+
st.session_state.atoms = None
|
| 1119 |
|
| 1120 |
+
# --- SELECT EXAMPLE ---
|
| 1121 |
elif input_method == "Select Example":
|
| 1122 |
+
# Load immediately upon selection change (no button needed)
|
| 1123 |
example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
|
| 1124 |
+
|
| 1125 |
+
# Only load if a valid example is selected and it's different from the current state
|
| 1126 |
+
if example_name and (st.session_state.atoms is None or st.session_state.atoms.info.get('source_name') != example_name):
|
| 1127 |
file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
|
| 1128 |
try:
|
| 1129 |
+
atoms_to_store = read(file_path)
|
| 1130 |
+
atoms_to_store.info['source_name'] = example_name # Add a tag for tracking
|
| 1131 |
+
st.session_state.atoms = atoms_to_store
|
| 1132 |
+
st.sidebar.success(f"Loaded {example_name} with {len(atoms_to_store)} atoms!")
|
| 1133 |
except Exception as e:
|
| 1134 |
st.sidebar.error(f"Error loading example: {str(e)}")
|
| 1135 |
+
st.session_state.atoms = None
|
| 1136 |
|
| 1137 |
+
# --- PASTE CONTENT ---
|
| 1138 |
elif input_method == "Paste Content":
|
| 1139 |
file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
|
| 1140 |
+
content = st.sidebar.text_area("Paste file content here:", height=200, key="paste_content_input")
|
| 1141 |
+
|
| 1142 |
+
# Load immediately upon content change (no button needed)
|
| 1143 |
+
# Check if content is present and is different from the last successfully parsed content
|
| 1144 |
if content:
|
| 1145 |
+
# Simple check to avoid parsing on every single character change
|
| 1146 |
+
if 'last_parsed_content' not in st.session_state or st.session_state.last_parsed_content != content:
|
| 1147 |
+
|
| 1148 |
+
try:
|
| 1149 |
+
suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
|
| 1150 |
+
suffix = suffix_map.get(file_format, ".xyz")
|
| 1151 |
+
|
| 1152 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
|
| 1153 |
+
tmp_file.write(content.encode())
|
| 1154 |
+
tmp_filepath = tmp_file.name
|
| 1155 |
+
|
| 1156 |
+
atoms_to_store = read(tmp_filepath)
|
| 1157 |
+
st.session_state.atoms = atoms_to_store
|
| 1158 |
+
st.session_state.last_parsed_content = content # Track the parsed content
|
| 1159 |
+
st.sidebar.success(f"Successfully parsed structure with {len(atoms_to_store)} atoms!")
|
| 1160 |
+
except Exception as e:
|
| 1161 |
+
st.sidebar.error(f"Error parsing content: {str(e)}")
|
| 1162 |
+
st.session_state.atoms = None
|
| 1163 |
+
st.session_state.last_parsed_content = None
|
| 1164 |
+
finally:
|
| 1165 |
+
if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
|
| 1166 |
+
os.unlink(tmp_filepath)
|
| 1167 |
+
else:
|
| 1168 |
+
# Clear structure if text area is empty
|
| 1169 |
+
st.session_state.atoms = None
|
| 1170 |
+
|
| 1171 |
+
|
| 1172 |
+
# --- MATERIALS PROJECT ID ---
|
| 1173 |
+
elif input_method == "Materials Project ID":
|
| 1174 |
+
mp_api_key = os.getenv("MP_API_KEY")
|
| 1175 |
+
material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149", key="mp_id_input")
|
| 1176 |
+
cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'], key="cell_type_radio")
|
| 1177 |
+
|
| 1178 |
+
# Reactive Loading (No button needed)
|
| 1179 |
+
# Check for valid inputs and if the current material_id/cell_type is different from the loaded one
|
| 1180 |
+
if mp_api_key and material_id:
|
| 1181 |
+
|
| 1182 |
+
# Simple tracking to avoid API call if nothing has changed
|
| 1183 |
+
current_mp_key = f"{material_id}_{cell_type}"
|
| 1184 |
+
if 'last_fetched_mp_key' not in st.session_state or st.session_state.last_fetched_mp_key != current_mp_key:
|
| 1185 |
+
|
| 1186 |
+
try:
|
| 1187 |
+
with st.spinner(f"Fetching {material_id}..."):
|
| 1188 |
+
with MPRester(mp_api_key) as mpr:
|
| 1189 |
+
pmg_structure = mpr.get_structure_by_material_id(material_id)
|
| 1190 |
+
analyzer = SpacegroupAnalyzer(pmg_structure)
|
| 1191 |
+
|
| 1192 |
+
if cell_type == 'Conventional Unit Cell':
|
| 1193 |
+
final_structure = analyzer.get_conventional_standard_structure()
|
| 1194 |
+
else:
|
| 1195 |
+
final_structure = analyzer.get_primitive_standard_structure()
|
| 1196 |
+
|
| 1197 |
+
atoms_to_store = AseAtomsAdaptor.get_atoms(final_structure)
|
| 1198 |
+
st.session_state.atoms = atoms_to_store
|
| 1199 |
+
st.session_state.last_fetched_mp_key = current_mp_key # Update tracking key
|
| 1200 |
+
st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
|
| 1201 |
+
|
| 1202 |
+
except Exception as e:
|
| 1203 |
+
st.sidebar.error(f"Error fetching data: {str(e)}")
|
| 1204 |
+
st.session_state.atoms = None
|
| 1205 |
+
st.session_state.last_fetched_mp_key = None # Clear key on failure
|
| 1206 |
+
|
| 1207 |
+
# Handle error messages when inputs are missing
|
| 1208 |
+
elif not mp_api_key:
|
| 1209 |
+
st.sidebar.error("Please set your Materials Project API Key (MP_API_KEY environment variable).")
|
| 1210 |
+
elif not material_id:
|
| 1211 |
+
st.sidebar.error("Please enter a Material ID.")
|
| 1212 |
+
|
| 1213 |
+
# ----------------------------------------------------
|
| 1214 |
+
# --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
|
| 1215 |
+
# ----------------------------------------------------
|
| 1216 |
+
# This is the single source of truth for the rest of your app
|
| 1217 |
+
atoms = st.session_state.atoms
|
| 1218 |
|
| 1219 |
if atoms is not None:
|
| 1220 |
if not hasattr(atoms, 'info'):
|
| 1221 |
atoms.info = {}
|
| 1222 |
atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
|
| 1223 |
atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1224 |
+
|
| 1225 |
+
# Display confirmation in the main area (optional, helps the user confirm what's loaded)
|
| 1226 |
+
st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
|
| 1227 |
+
|
| 1228 |
+
# You can now add your processing buttons here, which will consistently use the 'atoms' variable.
|
| 1229 |
|
| 1230 |
st.sidebar.markdown("## Model Selection")
|
| 1231 |
if mattersim_available:
|
|
|
|
| 1265 |
else:
|
| 1266 |
atoms.info["charge"] = 0
|
| 1267 |
atoms.info["spin"] = 1 # FairChem expects multiplicity
|
|
|
|
| 1268 |
if model_type == "ORB":
|
| 1269 |
selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
|
| 1270 |
model_path = ORB_MODELS[selected_model]
|
|
|
|
| 1296 |
task = st.sidebar.selectbox("Select Calculation Task:",
|
| 1297 |
["Energy Calculation",
|
| 1298 |
"Energy + Forces Calculation",
|
| 1299 |
+
"Atomization/Cohesive Energy",
|
| 1300 |
"Geometry Optimization",
|
| 1301 |
"Cell + Geometry Optimization",
|
| 1302 |
+
#"Global Optimization",
|
| 1303 |
"Vibrational Mode Analysis",
|
| 1304 |
#"Phonons"
|
| 1305 |
+
])
|
| 1306 |
|
| 1307 |
if "Optimization" in task:
|
| 1308 |
+
# st.sidebar.markdown("### Optimization Parameters")
|
| 1309 |
+
# max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
|
| 1310 |
+
# fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
|
| 1311 |
+
# optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
|
| 1312 |
st.sidebar.markdown("### Optimization Parameters")
|
| 1313 |
+
|
| 1314 |
+
# 1. Configuration for GLOBAL Optimization
|
| 1315 |
+
if task == "Global Optimization":
|
| 1316 |
+
global_method = st.sidebar.selectbox("Method:", ["Basin Hopping", "Minima Hopping"])
|
| 1317 |
+
|
| 1318 |
+
# Common parameters
|
| 1319 |
+
temperature_K = st.sidebar.number_input("Temperature (K):", min_value=10.0, max_value=2000.0, value=300.0, step=10.0)
|
| 1320 |
+
global_steps = st.sidebar.number_input("Search Steps:", min_value=10, max_value=500, value=50, step=10)
|
| 1321 |
+
# Basin Hopping specific
|
| 1322 |
+
if global_method == "Basin Hopping":
|
| 1323 |
+
dr_amp = st.sidebar.number_input("Displacement Amplitude (Å):", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
|
| 1324 |
+
fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
|
| 1325 |
+
|
| 1326 |
+
# Minima Hopping specific
|
| 1327 |
+
elif global_method == "Minima Hopping":
|
| 1328 |
+
st.sidebar.caption("Minima Hopping automates threshold adjustments to escape local minima.")
|
| 1329 |
+
fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
|
| 1330 |
|
| 1331 |
+
# 2. Configuration for LOCAL/CELL Optimization
|
| 1332 |
+
else:
|
| 1333 |
+
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
|
| 1334 |
+
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
|
| 1335 |
+
optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
|
| 1336 |
|
| 1337 |
if "Vibration" in task:
|
| 1338 |
st.write("### Thermodynamic Quantities (Molecule Only)")
|
|
|
|
| 1375 |
st.markdown("### Selected Task")
|
| 1376 |
st.write(f"**Task:** {task}")
|
| 1377 |
|
| 1378 |
+
if "Geometry Optimization" in task:
|
| 1379 |
st.write(f"**Max Steps:** {max_steps}")
|
| 1380 |
st.write(f"**Convergence Threshold:** {fmax} eV/Å")
|
| 1381 |
st.write(f"**Optimizer:** {optimizer_type}")
|
|
|
|
| 1514 |
results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
|
| 1515 |
results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
|
| 1516 |
|
| 1517 |
+
elif "Geometry Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
|
| 1518 |
is_periodic = any(calc_atoms.pbc)
|
| 1519 |
opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
|
| 1520 |
# Create temporary trajectory file
|
|
|
|
| 1646 |
)
|
| 1647 |
|
| 1648 |
show_trajectory_and_controls()
|
| 1649 |
+
elif task == "Global Optimization":
|
| 1650 |
+
st.info(f"Starting Global Optimization using {global_method}...")
|
| 1651 |
+
|
| 1652 |
+
# Create temporary trajectory file to store the "hopping" steps
|
| 1653 |
+
traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
|
| 1654 |
+
|
| 1655 |
+
# Container for live updates
|
| 1656 |
+
log_container = st.empty()
|
| 1657 |
+
global_min_energy = 0
|
| 1658 |
+
|
| 1659 |
+
def global_log(opt_instance):
|
| 1660 |
+
"""Helper to log global optimization steps."""
|
| 1661 |
+
global global_min_energy
|
| 1662 |
+
current_e = opt_instance.atoms.get_potential_energy()
|
| 1663 |
+
# For BasinHopping, nsteps is available. For others, we might need a counter.
|
| 1664 |
+
step = getattr(opt_instance, 'nsteps', 'N/A')
|
| 1665 |
+
log_container.write(f"Global Step: {step} | Energy: {current_e:.6f} eV")
|
| 1666 |
+
if current_e < global_min_energy:
|
| 1667 |
+
global_min_energy = current_e
|
| 1668 |
+
|
| 1669 |
+
if global_method == "Basin Hopping":
|
| 1670 |
+
# Basin Hopping requires Temperature in eV (kB * T)
|
| 1671 |
+
kT = temperature_K * kB
|
| 1672 |
+
|
| 1673 |
+
# Create the wrapper for the hack needed to enforce the optimization to stop when it reaches a certain number of steps
|
| 1674 |
+
class LimitedLBFGS(LBFGS):
|
| 1675 |
+
def run(self, fmax=0.05, steps=None):
|
| 1676 |
+
# 'steps' here overrides whatever BasinHopping tries to do.
|
| 1677 |
+
# Set your desired max local steps (e.g., 200)
|
| 1678 |
+
return super().run(fmax=fmax, steps=100)
|
| 1679 |
+
# Initialize Basin Hopping with the trajectory file
|
| 1680 |
+
bh = BasinHopping(calc_atoms,
|
| 1681 |
+
temperature=kT,
|
| 1682 |
+
dr=dr_amp,
|
| 1683 |
+
optimizer=LimitedLBFGS,
|
| 1684 |
+
fmax=fmax_local,
|
| 1685 |
+
trajectory=traj_filename) # Log steps to file automatically
|
| 1686 |
+
|
| 1687 |
+
# Attach the live logger
|
| 1688 |
+
bh.attach(lambda: global_log(bh), interval=1)
|
| 1689 |
+
|
| 1690 |
+
# Run the optimization
|
| 1691 |
+
bh.run(global_steps)
|
| 1692 |
+
|
| 1693 |
+
results["Global Minimum Energy"] = f"{global_min_energy:.6f} eV"
|
| 1694 |
+
results["Steps Taken"] = global_steps
|
| 1695 |
+
results["Converged"] = "N/A (Global Search)"
|
| 1696 |
+
|
| 1697 |
+
elif global_method == "Minima Hopping":
|
| 1698 |
+
# Minima Hopping manages its own internal optimizers and doesn't accept a 'trajectory'
|
| 1699 |
+
# file argument in the same way BasinHopping does in __init__.
|
| 1700 |
+
opt = MinimaHopping(calc_atoms,
|
| 1701 |
+
T0=temperature_K,
|
| 1702 |
+
fmax=fmax_local,
|
| 1703 |
+
optimizer=LBFGS)
|
| 1704 |
+
|
| 1705 |
+
# We run it. Live logging is harder here without subclassing,
|
| 1706 |
+
# so we rely on the final output for the trajectory.
|
| 1707 |
+
opt(totalsteps=global_steps)
|
| 1708 |
+
|
| 1709 |
+
results["Current Energy"] = f"{calc_atoms.get_potential_energy():.6f} eV"
|
| 1710 |
+
|
| 1711 |
+
# Post-processing: MinimaHopping stores visited minima in an internal list usually.
|
| 1712 |
+
# We explicitly write the found minima to the trajectory file so the visualizer below works.
|
| 1713 |
+
# Note: opt.minima is a list of Atoms objects found during the hop.
|
| 1714 |
+
if hasattr(opt, 'minima'):
|
| 1715 |
+
from ase.io import write
|
| 1716 |
+
write(traj_filename, opt.minima)
|
| 1717 |
+
else:
|
| 1718 |
+
# Fallback if specific version doesn't store list, just save final
|
| 1719 |
+
write(traj_filename, calc_atoms)
|
| 1720 |
+
|
| 1721 |
+
st.success("Global Optimization Complete!")
|
| 1722 |
+
|
| 1723 |
+
st.markdown("### Results")
|
| 1724 |
+
for key, value in results.items():
|
| 1725 |
+
st.write(f"**{key}:** {value}")
|
| 1726 |
+
|
| 1727 |
+
# --- Visualization and Downloading (Fragmented) ---
|
| 1728 |
+
|
| 1729 |
+
# 1. Clean up the temp file path for reading
|
| 1730 |
+
# We define the visualizer function using @st.fragment to prevent full re-runs
|
| 1731 |
+
|
| 1732 |
+
@st.fragment
|
| 1733 |
+
def show_global_trajectory_and_dl():
|
| 1734 |
+
from ase.io import read
|
| 1735 |
+
import py3Dmol
|
| 1736 |
+
|
| 1737 |
+
# Helper to convert atoms list to XYZ string for the single download file
|
| 1738 |
+
def atoms_list_to_xyz_string(atoms_list):
|
| 1739 |
+
xyz_str = ""
|
| 1740 |
+
for i, atoms in enumerate(atoms_list):
|
| 1741 |
+
xyz_str += f"{len(atoms)}\n"
|
| 1742 |
+
xyz_str += f"Step {i}, Energy = {atoms.get_potential_energy():.6f} eV\n"
|
| 1743 |
+
for atom in atoms:
|
| 1744 |
+
xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
|
| 1745 |
+
return xyz_str
|
| 1746 |
+
|
| 1747 |
+
if "global_traj_frames" not in st.session_state:
|
| 1748 |
+
if os.path.exists(traj_filename):
|
| 1749 |
+
try:
|
| 1750 |
+
# Read the trajectory we just created
|
| 1751 |
+
trajectory = read(traj_filename, index=":")
|
| 1752 |
+
st.session_state.global_traj_frames = trajectory
|
| 1753 |
+
st.session_state.global_traj_index = 0
|
| 1754 |
+
except Exception as e:
|
| 1755 |
+
st.error(f"Error reading trajectory: {e}")
|
| 1756 |
+
return
|
| 1757 |
+
else:
|
| 1758 |
+
st.warning("Trajectory file not generated.")
|
| 1759 |
+
return
|
| 1760 |
+
|
| 1761 |
+
trajectory = st.session_state.global_traj_frames
|
| 1762 |
+
|
| 1763 |
+
if not trajectory:
|
| 1764 |
+
st.warning("No steps recorded in trajectory.")
|
| 1765 |
+
return
|
| 1766 |
+
|
| 1767 |
+
index = st.session_state.global_traj_index
|
| 1768 |
+
|
| 1769 |
+
st.markdown("### Global Search Trajectory")
|
| 1770 |
+
st.write(f"Captured {len(trajectory)} hopping steps (Local Minima)")
|
| 1771 |
+
|
| 1772 |
+
# Navigation Controls
|
| 1773 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 1774 |
+
with col1:
|
| 1775 |
+
if st.button("⏮ First", key="g_first"):
|
| 1776 |
+
st.session_state.global_traj_index = 0
|
| 1777 |
+
with col2:
|
| 1778 |
+
if st.button("◀ Previous", key="g_prev") and index > 0:
|
| 1779 |
+
st.session_state.global_traj_index -= 1
|
| 1780 |
+
with col3:
|
| 1781 |
+
if st.button("Next ▶", key="g_next") and index < len(trajectory) - 1:
|
| 1782 |
+
st.session_state.global_traj_index += 1
|
| 1783 |
+
with col4:
|
| 1784 |
+
if st.button("Last ⏭", key="g_last"):
|
| 1785 |
+
st.session_state.global_traj_index = len(trajectory) - 1
|
| 1786 |
+
|
| 1787 |
+
# Display Visualization
|
| 1788 |
+
current_atoms = trajectory[st.session_state.global_traj_index]
|
| 1789 |
+
st.write(f"Frame {st.session_state.global_traj_index + 1}/{len(trajectory)} | E = {current_atoms.get_potential_energy():.4f} eV")
|
| 1790 |
+
|
| 1791 |
+
viz_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
|
| 1792 |
+
st.components.v1.html(viz_view._make_html(), width=400, height=400)
|
| 1793 |
+
|
| 1794 |
+
# Download Logic
|
| 1795 |
+
full_xyz_content = atoms_list_to_xyz_string(trajectory)
|
| 1796 |
+
|
| 1797 |
+
st.download_button(
|
| 1798 |
+
label="Download Trajectory (XYZ)",
|
| 1799 |
+
data=full_xyz_content,
|
| 1800 |
+
file_name="global_optimization_path.xyz",
|
| 1801 |
+
mime="chemical/x-xyz"
|
| 1802 |
+
)
|
| 1803 |
+
|
| 1804 |
+
# Separate Download for just the Best Structure (Last frame usually in BH, or sorted)
|
| 1805 |
+
# Often in BH, the last frame is the accepted state, but not necessarily the global min seen *ever*.
|
| 1806 |
+
# But usually, we want the lowest energy one.
|
| 1807 |
+
energies = [a.get_potential_energy() for a in trajectory]
|
| 1808 |
+
best_idx = np.argmin(energies)
|
| 1809 |
+
best_atoms = trajectory[best_idx]
|
| 1810 |
+
|
| 1811 |
+
# Create XYZ for single best
|
| 1812 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix=".xyz", delete=False) as tmp_best:
|
| 1813 |
+
write(tmp_best.name, best_atoms)
|
| 1814 |
+
tmp_best_name = tmp_best.name
|
| 1815 |
+
|
| 1816 |
+
with open(tmp_best_name, "r") as f:
|
| 1817 |
+
st.download_button(
|
| 1818 |
+
label=f"Download Best Structure (E={energies[best_idx]:.4f} eV)",
|
| 1819 |
+
data=f.read(),
|
| 1820 |
+
file_name="best_global_structure.xyz",
|
| 1821 |
+
mime="chemical/x-xyz"
|
| 1822 |
+
)
|
| 1823 |
+
os.unlink(tmp_best_name)
|
| 1824 |
+
|
| 1825 |
+
# Call the fragment function
|
| 1826 |
+
show_global_trajectory_and_dl()
|
| 1827 |
+
|
| 1828 |
+
# Cleanup main trajectory file after loading it into session state if desired,
|
| 1829 |
+
# though keeping it until session end is safer for re-reads.
|
| 1830 |
+
# os.unlink(traj_filename)
|
| 1831 |
elif task == "Vibrational Mode Analysis":
|
| 1832 |
# Conversion factors
|
| 1833 |
from ase.units import kB as kB_eVK, _Nav, J # ASE's constants
|
|
|
|
| 1848 |
vib.run()
|
| 1849 |
freqs = vib.get_frequencies()
|
| 1850 |
energies = vib.get_energies()
|
|
|
|
| 1851 |
|
|
|
|
| 1852 |
print('\n\n\n\n\n\n\n\n')
|
| 1853 |
# vib.get_hessian_2d()
|
| 1854 |
# st.write(vib.summary())
|
| 1855 |
# print('\n')
|
| 1856 |
# vib.tabulate()
|
|
|
|
| 1857 |
|
|
|
|
| 1858 |
freqs_cm = freqs
|
| 1859 |
freqs_eV = energies
|
|
|
|
| 1860 |
|
| 1861 |
# Classify frequencies
|
| 1862 |
mode_data = []
|
|
|
|
| 2022 |
- For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
|
| 2023 |
""")
|
| 2024 |
|
| 2025 |
+
|
| 2026 |
with st.expander('🔧 Tech Stack & System Information'):
|
| 2027 |
import platform
|
| 2028 |
import psutil
|