ManasSharma07 commited on
Commit
da79a56
·
verified ·
1 Parent(s): fde6971

add materials project ID support

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +461 -39
src/streamlit_app.py CHANGED
@@ -4,13 +4,16 @@ import io
4
  import tempfile
5
  import torch
6
  # FOR CPU only mode
7
- # torch._dynamo.config.suppress_errors = True
8
  # Or disable compilation entirely
9
  # torch.backends.cudnn.enabled = False
10
  import numpy as np
11
  from ase import Atoms
12
  from ase.io import read, write
13
  from ase.optimize import BFGS, LBFGS, FIRE
 
 
 
14
  from ase.constraints import FixAtoms
15
  from ase.filters import FrechetCellFilter
16
  from ase.visualize import view
@@ -26,6 +29,9 @@ import subprocess
26
  import sys
27
  import pkg_resources
28
  from ase.vibrations import Vibrations
 
 
 
29
  import matplotlib.pyplot as plt
30
  mattersim_available = True
31
  if mattersim_available:
@@ -894,6 +900,7 @@ def check_atom_limit(atoms_obj, selected_model):
894
  MACE_MODELS = {
895
  "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
896
  "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
 
897
  "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
898
  "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
899
  "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
@@ -952,59 +959,273 @@ def get_fairchem_model(selected_model_name, model_path_or_name, device, selected
952
  calc = FAIRChemCalculator(predictor, task_name="omol")
953
  return calc
954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
955
  st.sidebar.markdown("## Input Options")
956
- input_method = st.sidebar.radio("Choose Input Method:", ["Select Example", "Upload File", "Paste Content"])
957
- atoms = None
958
 
 
 
 
 
 
 
959
  if input_method == "Upload File":
960
  uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
 
 
961
  if uploaded_file:
962
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
963
- tmp_file.write(uploaded_file.getvalue())
964
- tmp_filepath = tmp_file.name
965
  try:
966
- atoms = read(tmp_filepath)
967
- st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
 
 
 
 
 
 
 
 
 
 
 
968
  except Exception as e:
969
  st.sidebar.error(f"Error loading file: {str(e)}")
 
 
970
  finally:
 
971
  if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
972
  os.unlink(tmp_filepath)
 
 
 
973
 
 
974
  elif input_method == "Select Example":
 
975
  example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
976
- if example_name:
 
 
977
  file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
978
  try:
979
- atoms = read(file_path)
980
- st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
 
 
981
  except Exception as e:
982
  st.sidebar.error(f"Error loading example: {str(e)}")
 
983
 
 
984
  elif input_method == "Paste Content":
985
  file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
986
- content = st.sidebar.text_area("Paste file content here:", height=200)
 
 
 
987
  if content:
988
- try:
989
- suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
990
- suffix = suffix_map.get(file_format, ".xyz")
991
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
992
- tmp_file.write(content.encode())
993
- tmp_filepath = tmp_file.name
994
- atoms = read(tmp_filepath)
995
- st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
996
- except Exception as e:
997
- st.sidebar.error(f"Error parsing content: {str(e)}")
998
- finally:
999
- if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
1000
- os.unlink(tmp_filepath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1001
 
1002
  if atoms is not None:
1003
  if not hasattr(atoms, 'info'):
1004
  atoms.info = {}
1005
  atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
1006
  atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1007
-
 
 
 
 
1008
 
1009
  st.sidebar.markdown("## Model Selection")
1010
  if mattersim_available:
@@ -1044,7 +1265,6 @@ if model_type == "FairChem":
1044
  else:
1045
  atoms.info["charge"] = 0
1046
  atoms.info["spin"] = 1 # FairChem expects multiplicity
1047
-
1048
  if model_type == "ORB":
1049
  selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
1050
  model_path = ORB_MODELS[selected_model]
@@ -1076,19 +1296,43 @@ st.sidebar.markdown("## Task Selection")
1076
  task = st.sidebar.selectbox("Select Calculation Task:",
1077
  ["Energy Calculation",
1078
  "Energy + Forces Calculation",
1079
- "Atomization/Cohesive Energy", # New Task Added
1080
  "Geometry Optimization",
1081
  "Cell + Geometry Optimization",
 
1082
  "Vibrational Mode Analysis",
1083
  #"Phonons"
1084
- ])
1085
 
1086
  if "Optimization" in task:
 
 
 
 
1087
  st.sidebar.markdown("### Optimization Parameters")
1088
- max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
1089
- fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
1090
- optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1091
 
 
 
 
 
 
1092
 
1093
  if "Vibration" in task:
1094
  st.write("### Thermodynamic Quantities (Molecule Only)")
@@ -1131,7 +1375,7 @@ if atoms is not None:
1131
  st.markdown("### Selected Task")
1132
  st.write(f"**Task:** {task}")
1133
 
1134
- if "Optimization" in task:
1135
  st.write(f"**Max Steps:** {max_steps}")
1136
  st.write(f"**Convergence Threshold:** {fmax} eV/Å")
1137
  st.write(f"**Optimizer:** {optimizer_type}")
@@ -1270,7 +1514,7 @@ if atoms is not None:
1270
  results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
1271
  results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
1272
 
1273
- elif "Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
1274
  is_periodic = any(calc_atoms.pbc)
1275
  opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
1276
  # Create temporary trajectory file
@@ -1402,6 +1646,188 @@ if atoms is not None:
1402
  )
1403
 
1404
  show_trajectory_and_controls()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1405
  elif task == "Vibrational Mode Analysis":
1406
  # Conversion factors
1407
  from ase.units import kB as kB_eVK, _Nav, J # ASE's constants
@@ -1422,20 +1848,15 @@ if atoms is not None:
1422
  vib.run()
1423
  freqs = vib.get_frequencies()
1424
  energies = vib.get_energies()
1425
-
1426
 
1427
-
1428
  print('\n\n\n\n\n\n\n\n')
1429
  # vib.get_hessian_2d()
1430
  # st.write(vib.summary())
1431
  # print('\n')
1432
  # vib.tabulate()
1433
-
1434
 
1435
-
1436
  freqs_cm = freqs
1437
  freqs_eV = energies
1438
-
1439
 
1440
  # Classify frequencies
1441
  mode_data = []
@@ -1601,6 +2022,7 @@ with st.expander('ℹ️ About This App & Foundational MLIPs'):
1601
  - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
1602
  """)
1603
 
 
1604
  with st.expander('🔧 Tech Stack & System Information'):
1605
  import platform
1606
  import psutil
 
4
  import tempfile
5
  import torch
6
  # FOR CPU only mode
7
+ #torch._dynamo.config.suppress_errors = True
8
  # Or disable compilation entirely
9
  # torch.backends.cudnn.enabled = False
10
  import numpy as np
11
  from ase import Atoms
12
  from ase.io import read, write
13
  from ase.optimize import BFGS, LBFGS, FIRE
14
+ from ase.optimize.basin import BasinHopping
15
+ from ase.optimize.minimahopping import MinimaHopping
16
+ from ase.units import kB
17
  from ase.constraints import FixAtoms
18
  from ase.filters import FrechetCellFilter
19
  from ase.visualize import view
 
29
  import sys
30
  import pkg_resources
31
  from ase.vibrations import Vibrations
32
+ from mp_api.client import MPRester
33
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
34
+ from pymatgen.io.ase import AseAtomsAdaptor
35
  import matplotlib.pyplot as plt
36
  mattersim_available = True
37
  if mattersim_available:
 
900
  MACE_MODELS = {
901
  "MACE MPA Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mpa_0/mace-mpa-0-medium.model",
902
  "MACE OMAT Medium": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model",
903
+ "MACE OMAT Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-small.model",
904
  "MACE MATPES r2SCAN Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model",
905
  "MACE MATPES PBE Medium": "https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model",
906
  "MACE MP 0a Small": "https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0/2023-12-10-mace-128-L0_energy_epoch-249.model",
 
959
  calc = FAIRChemCalculator(predictor, task_name="omol")
960
  return calc
961
 
962
+ # st.sidebar.markdown("## Input Options")
963
+ # input_method = st.sidebar.radio("Choose Input Method:",
964
+ # ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
965
+ # atoms = None
966
+
967
+ # if input_method == "Upload File":
968
+ # uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
969
+ # if uploaded_file:
970
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
971
+ # tmp_file.write(uploaded_file.getvalue())
972
+ # tmp_filepath = tmp_file.name
973
+ # try:
974
+ # atoms = read(tmp_filepath)
975
+ # st.sidebar.success(f"Successfully loaded structure with {len(atoms)} atoms!")
976
+ # except Exception as e:
977
+ # st.sidebar.error(f"Error loading file: {str(e)}")
978
+ # finally:
979
+ # if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
980
+ # os.unlink(tmp_filepath)
981
+
982
+ # elif input_method == "Select Example":
983
+ # example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
984
+ # if example_name:
985
+ # file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
986
+ # try:
987
+ # atoms = read(file_path)
988
+ # st.sidebar.success(f"Loaded {example_name} with {len(atoms)} atoms!")
989
+ # except Exception as e:
990
+ # st.sidebar.error(f"Error loading example: {str(e)}")
991
+
992
+ # elif input_method == "Paste Content":
993
+ # file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
994
+ # content = st.sidebar.text_area("Paste file content here:", height=200)
995
+ # if content:
996
+ # try:
997
+ # suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
998
+ # suffix = suffix_map.get(file_format, ".xyz")
999
+ # with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
1000
+ # tmp_file.write(content.encode())
1001
+ # tmp_filepath = tmp_file.name
1002
+ # atoms = read(tmp_filepath)
1003
+ # st.sidebar.success(f"Successfully parsed structure with {len(atoms)} atoms!")
1004
+ # except Exception as e:
1005
+ # st.sidebar.error(f"Error parsing content: {str(e)}")
1006
+ # finally:
1007
+ # if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
1008
+ # os.unlink(tmp_filepath)
1009
+ # elif input_method == "Materials Project ID":
1010
+ # # st.sidebar.info("Requires an API Key from [Materials Project](https://next-gen.materialsproject.org/api)")
1011
+
1012
+ # # 1. Get API Key (You can also set this as an env variable)
1013
+ # # mp_api_key = st.sidebar.text_input("MP API Key", type="password")
1014
+ # mp_api_key = os.getenv("MP_API_KEY")
1015
+ # # 2. Get Material ID
1016
+ # material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149")
1017
+
1018
+ # # 3. Choose Cell Type
1019
+ # cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'])
1020
+
1021
+ # # 1. Initialize the session state variable if it doesn't exist
1022
+ # if "atoms" not in st.session_state:
1023
+ # st.session_state.atoms = None
1024
+
1025
+ # if st.sidebar.button("Fetch Structure"):
1026
+ # if not mp_api_key:
1027
+ # st.sidebar.error("Please enter your Materials Project API Key.")
1028
+ # elif not material_id:
1029
+ # st.sidebar.error("Please enter a Material ID.")
1030
+ # else:
1031
+ # try:
1032
+ # with st.spinner(f"Fetching {material_id}..."):
1033
+ # # Connect to MP
1034
+ # with MPRester(mp_api_key) as mpr:
1035
+ # # Get Pymatgen Structure
1036
+ # pmg_structure = mpr.get_structure_by_material_id(material_id)
1037
+
1038
+ # # Handle Symmetry (Primitive vs Conventional)
1039
+ # analyzer = SpacegroupAnalyzer(pmg_structure)
1040
+
1041
+ # if cell_type == 'Conventional Unit Cell':
1042
+ # final_structure = analyzer.get_conventional_standard_structure()
1043
+ # else:
1044
+ # final_structure = analyzer.get_primitive_standard_structure()
1045
+
1046
+ # # Convert Pymatgen Structure -> ASE Atoms
1047
+ # # 2. Store the result in session_state instead of a local variable
1048
+ # st.session_state.atoms = AseAtomsAdaptor.get_atoms(final_structure)
1049
+
1050
+ # st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
1051
+
1052
+ # except Exception as e:
1053
+ # st.sidebar.error(f"Error fetching data: {str(e)}")
1054
+ # st.session_state.atoms = None
1055
+
1056
+ # # 3. Retrieve the data from session_state for use in the rest of your app
1057
+ # atoms = st.session_state.atoms
1058
+
1059
+ # # --- Rest of your processing code ---
1060
+ # # if atoms is not None:
1061
+ # # st.write(f"System is ready. Formula: {atoms.get_chemical_formula()}")
1062
+ # # else:
1063
+ # # st.info("Please fetch a structure to begin.")
1064
+ # if atoms is not None:
1065
+ # if not hasattr(atoms, 'info'):
1066
+ # atoms.info = {}
1067
+ # atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
1068
+ # atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1069
+
1070
+ # --- INITIALIZATION (Must be run first) ---
1071
+ if "atoms" not in st.session_state:
1072
+ st.session_state.atoms = None
1073
+
1074
+ # Reset atoms state if input method changes, to prevent using old data
1075
+ # Use a key to track the currently active input method
1076
+ if 'current_input_method' not in st.session_state:
1077
+ st.session_state.current_input_method = "Select Example"
1078
+
1079
  st.sidebar.markdown("## Input Options")
1080
+ input_method = st.sidebar.radio("Choose Input Method:",
1081
+ ["Select Example", "Upload File", "Paste Content", "Materials Project ID"])
1082
 
1083
+ # If the input method changes, clear the loaded structure
1084
+ if input_method != st.session_state.current_input_method:
1085
+ st.session_state.atoms = None
1086
+ st.session_state.current_input_method = input_method
1087
+
1088
+ # --- UPLOAD FILE ---
1089
  if input_method == "Upload File":
1090
  uploaded_file = st.sidebar.file_uploader("Upload structure file", type=["xyz", "cif", "POSCAR", "mol", "tmol", "vasp", "sdf", "CONTCAR"])
1091
+
1092
+ # Load immediately upon file upload/change (no button needed)
1093
  if uploaded_file:
 
 
 
1094
  try:
1095
+ # Check if this file content has already been loaded to prevent redundant temp file operations
1096
+ if 'uploaded_file_hash' not in st.session_state or st.session_state.uploaded_file_hash != uploaded_file.name:
1097
+
1098
+ # Use tempfile to handle the uploaded file content
1099
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
1100
+ tmp_file.write(uploaded_file.getvalue())
1101
+ tmp_filepath = tmp_file.name
1102
+
1103
+ atoms_to_store = read(tmp_filepath)
1104
+ st.session_state.atoms = atoms_to_store
1105
+ st.session_state.uploaded_file_hash = uploaded_file.name # Track the loaded file
1106
+ st.sidebar.success(f"Successfully loaded structure with {len(atoms_to_store)} atoms!")
1107
+
1108
  except Exception as e:
1109
  st.sidebar.error(f"Error loading file: {str(e)}")
1110
+ st.session_state.atoms = None
1111
+ st.session_state.uploaded_file_hash = None # Clear hash on failure
1112
  finally:
1113
+ # Clean up the temporary file
1114
  if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
1115
  os.unlink(tmp_filepath)
1116
+ else:
1117
+ # Clear structure if file uploader is empty
1118
+ st.session_state.atoms = None
1119
 
1120
+ # --- SELECT EXAMPLE ---
1121
  elif input_method == "Select Example":
1122
+ # Load immediately upon selection change (no button needed)
1123
  example_name = st.sidebar.selectbox("Select Example Structure:", list(SAMPLE_STRUCTURES.keys()))
1124
+
1125
+ # Only load if a valid example is selected and it's different from the current state
1126
+ if example_name and (st.session_state.atoms is None or st.session_state.atoms.info.get('source_name') != example_name):
1127
  file_path = os.path.join(SAMPLE_DIR, SAMPLE_STRUCTURES[example_name])
1128
  try:
1129
+ atoms_to_store = read(file_path)
1130
+ atoms_to_store.info['source_name'] = example_name # Add a tag for tracking
1131
+ st.session_state.atoms = atoms_to_store
1132
+ st.sidebar.success(f"Loaded {example_name} with {len(atoms_to_store)} atoms!")
1133
  except Exception as e:
1134
  st.sidebar.error(f"Error loading example: {str(e)}")
1135
+ st.session_state.atoms = None
1136
 
1137
+ # --- PASTE CONTENT ---
1138
  elif input_method == "Paste Content":
1139
  file_format = st.sidebar.selectbox("File Format:", ["XYZ", "CIF", "extXYZ", "POSCAR (VASP)", "Turbomole", "MOL"])
1140
+ content = st.sidebar.text_area("Paste file content here:", height=200, key="paste_content_input")
1141
+
1142
+ # Load immediately upon content change (no button needed)
1143
+ # Check if content is present and is different from the last successfully parsed content
1144
  if content:
1145
+ # Simple check to avoid parsing on every single character change
1146
+ if 'last_parsed_content' not in st.session_state or st.session_state.last_parsed_content != content:
1147
+
1148
+ try:
1149
+ suffix_map = {"XYZ": ".xyz", "CIF": ".cif", "extXYZ": ".extxyz", "POSCAR (VASP)": ".vasp", "Turbomole": ".tmol", "MOL": ".mol"}
1150
+ suffix = suffix_map.get(file_format, ".xyz")
1151
+
1152
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
1153
+ tmp_file.write(content.encode())
1154
+ tmp_filepath = tmp_file.name
1155
+
1156
+ atoms_to_store = read(tmp_filepath)
1157
+ st.session_state.atoms = atoms_to_store
1158
+ st.session_state.last_parsed_content = content # Track the parsed content
1159
+ st.sidebar.success(f"Successfully parsed structure with {len(atoms_to_store)} atoms!")
1160
+ except Exception as e:
1161
+ st.sidebar.error(f"Error parsing content: {str(e)}")
1162
+ st.session_state.atoms = None
1163
+ st.session_state.last_parsed_content = None
1164
+ finally:
1165
+ if 'tmp_filepath' in locals() and os.path.exists(tmp_filepath):
1166
+ os.unlink(tmp_filepath)
1167
+ else:
1168
+ # Clear structure if text area is empty
1169
+ st.session_state.atoms = None
1170
+
1171
+
1172
+ # --- MATERIALS PROJECT ID ---
1173
+ elif input_method == "Materials Project ID":
1174
+ mp_api_key = os.getenv("MP_API_KEY")
1175
+ material_id = st.sidebar.text_input("Enter Material ID:", value="mp-149", key="mp_id_input")
1176
+ cell_type = st.sidebar.radio("Unit Cell Type:", ['Primitive Cell', 'Conventional Unit Cell'], key="cell_type_radio")
1177
+
1178
+ # Reactive Loading (No button needed)
1179
+ # Check for valid inputs and if the current material_id/cell_type is different from the loaded one
1180
+ if mp_api_key and material_id:
1181
+
1182
+ # Simple tracking to avoid API call if nothing has changed
1183
+ current_mp_key = f"{material_id}_{cell_type}"
1184
+ if 'last_fetched_mp_key' not in st.session_state or st.session_state.last_fetched_mp_key != current_mp_key:
1185
+
1186
+ try:
1187
+ with st.spinner(f"Fetching {material_id}..."):
1188
+ with MPRester(mp_api_key) as mpr:
1189
+ pmg_structure = mpr.get_structure_by_material_id(material_id)
1190
+ analyzer = SpacegroupAnalyzer(pmg_structure)
1191
+
1192
+ if cell_type == 'Conventional Unit Cell':
1193
+ final_structure = analyzer.get_conventional_standard_structure()
1194
+ else:
1195
+ final_structure = analyzer.get_primitive_standard_structure()
1196
+
1197
+ atoms_to_store = AseAtomsAdaptor.get_atoms(final_structure)
1198
+ st.session_state.atoms = atoms_to_store
1199
+ st.session_state.last_fetched_mp_key = current_mp_key # Update tracking key
1200
+ st.sidebar.success(f"Loaded {material_id} ({cell_type}) with {len(st.session_state.atoms)} atoms.")
1201
+
1202
+ except Exception as e:
1203
+ st.sidebar.error(f"Error fetching data: {str(e)}")
1204
+ st.session_state.atoms = None
1205
+ st.session_state.last_fetched_mp_key = None # Clear key on failure
1206
+
1207
+ # Handle error messages when inputs are missing
1208
+ elif not mp_api_key:
1209
+ st.sidebar.error("Please set your Materials Project API Key (MP_API_KEY environment variable).")
1210
+ elif not material_id:
1211
+ st.sidebar.error("Please enter a Material ID.")
1212
+
1213
+ # ----------------------------------------------------
1214
+ # --- FINAL STRUCTURE RETRIEVAL (The persistent structure) ---
1215
+ # ----------------------------------------------------
1216
+ # This is the single source of truth for the rest of your app
1217
+ atoms = st.session_state.atoms
1218
 
1219
  if atoms is not None:
1220
  if not hasattr(atoms, 'info'):
1221
  atoms.info = {}
1222
  atoms.info["charge"] = atoms.info.get("charge", 0) # Default charge
1223
  atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1224
+
1225
+ # Display confirmation in the main area (optional, helps the user confirm what's loaded)
1226
+ st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
1227
+
1228
+ # You can now add your processing buttons here, which will consistently use the 'atoms' variable.
1229
 
1230
  st.sidebar.markdown("## Model Selection")
1231
  if mattersim_available:
 
1265
  else:
1266
  atoms.info["charge"] = 0
1267
  atoms.info["spin"] = 1 # FairChem expects multiplicity
 
1268
  if model_type == "ORB":
1269
  selected_model = st.sidebar.selectbox("Select ORB Model:", list(ORB_MODELS.keys()))
1270
  model_path = ORB_MODELS[selected_model]
 
1296
  task = st.sidebar.selectbox("Select Calculation Task:",
1297
  ["Energy Calculation",
1298
  "Energy + Forces Calculation",
1299
+ "Atomization/Cohesive Energy",
1300
  "Geometry Optimization",
1301
  "Cell + Geometry Optimization",
1302
+ #"Global Optimization",
1303
  "Vibrational Mode Analysis",
1304
  #"Phonons"
1305
+ ])
1306
 
1307
  if "Optimization" in task:
1308
+ # st.sidebar.markdown("### Optimization Parameters")
1309
+ # max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1) # Increased max_steps
1310
+ # fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f") # Adjusted default fmax
1311
+ # optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1) # Renamed to optimizer_type
1312
  st.sidebar.markdown("### Optimization Parameters")
1313
+
1314
+ # 1. Configuration for GLOBAL Optimization
1315
+ if task == "Global Optimization":
1316
+ global_method = st.sidebar.selectbox("Method:", ["Basin Hopping", "Minima Hopping"])
1317
+
1318
+ # Common parameters
1319
+ temperature_K = st.sidebar.number_input("Temperature (K):", min_value=10.0, max_value=2000.0, value=300.0, step=10.0)
1320
+ global_steps = st.sidebar.number_input("Search Steps:", min_value=10, max_value=500, value=50, step=10)
1321
+ # Basin Hopping specific
1322
+ if global_method == "Basin Hopping":
1323
+ dr_amp = st.sidebar.number_input("Displacement Amplitude (Å):", min_value=0.1, max_value=2.0, value=0.7, step=0.1)
1324
+ fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
1325
+
1326
+ # Minima Hopping specific
1327
+ elif global_method == "Minima Hopping":
1328
+ st.sidebar.caption("Minima Hopping automates threshold adjustments to escape local minima.")
1329
+ fmax_local = st.sidebar.number_input("Local Relaxation Threshold (eV/Å):", value=0.05, format="%.3f")
1330
 
1331
+ # 2. Configuration for LOCAL/CELL Optimization
1332
+ else:
1333
+ max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
1334
+ fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
1335
+ optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
1336
 
1337
  if "Vibration" in task:
1338
  st.write("### Thermodynamic Quantities (Molecule Only)")
 
1375
  st.markdown("### Selected Task")
1376
  st.write(f"**Task:** {task}")
1377
 
1378
+ if "Geometry Optimization" in task:
1379
  st.write(f"**Max Steps:** {max_steps}")
1380
  st.write(f"**Convergence Threshold:** {fmax} eV/Å")
1381
  st.write(f"**Optimizer:** {optimizer_type}")
 
1514
  results["System Energy ($E_{system}$)"] = f"{E_system:.6f} eV"
1515
  results["Total Isolated Atom Energy ($\sum E_{atoms}$)"] = f"{E_isolated_atoms_total:.6f} eV"
1516
 
1517
+ elif "Geometry Optimization" in task: # Handles both Geometry and Cell+Geometry Opt
1518
  is_periodic = any(calc_atoms.pbc)
1519
  opt_atoms_obj = FrechetCellFilter(calc_atoms) if task == "Cell + Geometry Optimization" else calc_atoms
1520
  # Create temporary trajectory file
 
1646
  )
1647
 
1648
  show_trajectory_and_controls()
1649
+ elif task == "Global Optimization":
1650
+ st.info(f"Starting Global Optimization using {global_method}...")
1651
+
1652
+ # Create temporary trajectory file to store the "hopping" steps
1653
+ traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
1654
+
1655
+ # Container for live updates
1656
+ log_container = st.empty()
1657
+ global_min_energy = 0
1658
+
1659
+ def global_log(opt_instance):
1660
+ """Helper to log global optimization steps."""
1661
+ global global_min_energy
1662
+ current_e = opt_instance.atoms.get_potential_energy()
1663
+ # For BasinHopping, nsteps is available. For others, we might need a counter.
1664
+ step = getattr(opt_instance, 'nsteps', 'N/A')
1665
+ log_container.write(f"Global Step: {step} | Energy: {current_e:.6f} eV")
1666
+ if current_e < global_min_energy:
1667
+ global_min_energy = current_e
1668
+
1669
+ if global_method == "Basin Hopping":
1670
+ # Basin Hopping requires Temperature in eV (kB * T)
1671
+ kT = temperature_K * kB
1672
+
1673
+ # Create the wrapper for the hack needed to enforce the optimization to stop when it reaches a certain number of steps
1674
+ class LimitedLBFGS(LBFGS):
1675
+ def run(self, fmax=0.05, steps=None):
1676
+ # 'steps' here overrides whatever BasinHopping tries to do.
1677
+ # Set your desired max local steps (e.g., 200)
1678
+ return super().run(fmax=fmax, steps=100)
1679
+ # Initialize Basin Hopping with the trajectory file
1680
+ bh = BasinHopping(calc_atoms,
1681
+ temperature=kT,
1682
+ dr=dr_amp,
1683
+ optimizer=LimitedLBFGS,
1684
+ fmax=fmax_local,
1685
+ trajectory=traj_filename) # Log steps to file automatically
1686
+
1687
+ # Attach the live logger
1688
+ bh.attach(lambda: global_log(bh), interval=1)
1689
+
1690
+ # Run the optimization
1691
+ bh.run(global_steps)
1692
+
1693
+ results["Global Minimum Energy"] = f"{global_min_energy:.6f} eV"
1694
+ results["Steps Taken"] = global_steps
1695
+ results["Converged"] = "N/A (Global Search)"
1696
+
1697
+ elif global_method == "Minima Hopping":
1698
+ # Minima Hopping manages its own internal optimizers and doesn't accept a 'trajectory'
1699
+ # file argument in the same way BasinHopping does in __init__.
1700
+ opt = MinimaHopping(calc_atoms,
1701
+ T0=temperature_K,
1702
+ fmax=fmax_local,
1703
+ optimizer=LBFGS)
1704
+
1705
+ # We run it. Live logging is harder here without subclassing,
1706
+ # so we rely on the final output for the trajectory.
1707
+ opt(totalsteps=global_steps)
1708
+
1709
+ results["Current Energy"] = f"{calc_atoms.get_potential_energy():.6f} eV"
1710
+
1711
+ # Post-processing: MinimaHopping stores visited minima in an internal list usually.
1712
+ # We explicitly write the found minima to the trajectory file so the visualizer below works.
1713
+ # Note: opt.minima is a list of Atoms objects found during the hop.
1714
+ if hasattr(opt, 'minima'):
1715
+ from ase.io import write
1716
+ write(traj_filename, opt.minima)
1717
+ else:
1718
+ # Fallback if specific version doesn't store list, just save final
1719
+ write(traj_filename, calc_atoms)
1720
+
1721
+ st.success("Global Optimization Complete!")
1722
+
1723
+ st.markdown("### Results")
1724
+ for key, value in results.items():
1725
+ st.write(f"**{key}:** {value}")
1726
+
1727
+ # --- Visualization and Downloading (Fragmented) ---
1728
+
1729
+ # 1. Clean up the temp file path for reading
1730
+ # We define the visualizer function using @st.fragment to prevent full re-runs
1731
+
1732
+ @st.fragment
1733
+ def show_global_trajectory_and_dl():
1734
+ from ase.io import read
1735
+ import py3Dmol
1736
+
1737
+ # Helper to convert atoms list to XYZ string for the single download file
1738
+ def atoms_list_to_xyz_string(atoms_list):
1739
+ xyz_str = ""
1740
+ for i, atoms in enumerate(atoms_list):
1741
+ xyz_str += f"{len(atoms)}\n"
1742
+ xyz_str += f"Step {i}, Energy = {atoms.get_potential_energy():.6f} eV\n"
1743
+ for atom in atoms:
1744
+ xyz_str += f"{atom.symbol} {atom.position[0]:.6f} {atom.position[1]:.6f} {atom.position[2]:.6f}\n"
1745
+ return xyz_str
1746
+
1747
+ if "global_traj_frames" not in st.session_state:
1748
+ if os.path.exists(traj_filename):
1749
+ try:
1750
+ # Read the trajectory we just created
1751
+ trajectory = read(traj_filename, index=":")
1752
+ st.session_state.global_traj_frames = trajectory
1753
+ st.session_state.global_traj_index = 0
1754
+ except Exception as e:
1755
+ st.error(f"Error reading trajectory: {e}")
1756
+ return
1757
+ else:
1758
+ st.warning("Trajectory file not generated.")
1759
+ return
1760
+
1761
+ trajectory = st.session_state.global_traj_frames
1762
+
1763
+ if not trajectory:
1764
+ st.warning("No steps recorded in trajectory.")
1765
+ return
1766
+
1767
+ index = st.session_state.global_traj_index
1768
+
1769
+ st.markdown("### Global Search Trajectory")
1770
+ st.write(f"Captured {len(trajectory)} hopping steps (Local Minima)")
1771
+
1772
+ # Navigation Controls
1773
+ col1, col2, col3, col4 = st.columns(4)
1774
+ with col1:
1775
+ if st.button("⏮ First", key="g_first"):
1776
+ st.session_state.global_traj_index = 0
1777
+ with col2:
1778
+ if st.button("◀ Previous", key="g_prev") and index > 0:
1779
+ st.session_state.global_traj_index -= 1
1780
+ with col3:
1781
+ if st.button("Next ▶", key="g_next") and index < len(trajectory) - 1:
1782
+ st.session_state.global_traj_index += 1
1783
+ with col4:
1784
+ if st.button("Last ⏭", key="g_last"):
1785
+ st.session_state.global_traj_index = len(trajectory) - 1
1786
+
1787
+ # Display Visualization
1788
+ current_atoms = trajectory[st.session_state.global_traj_index]
1789
+ st.write(f"Frame {st.session_state.global_traj_index + 1}/{len(trajectory)} | E = {current_atoms.get_potential_energy():.4f} eV")
1790
+
1791
+ viz_view = get_structure_viz2(current_atoms, style=viz_style, show_unit_cell=True, width=400, height=400)
1792
+ st.components.v1.html(viz_view._make_html(), width=400, height=400)
1793
+
1794
+ # Download Logic
1795
+ full_xyz_content = atoms_list_to_xyz_string(trajectory)
1796
+
1797
+ st.download_button(
1798
+ label="Download Trajectory (XYZ)",
1799
+ data=full_xyz_content,
1800
+ file_name="global_optimization_path.xyz",
1801
+ mime="chemical/x-xyz"
1802
+ )
1803
+
1804
+ # Separate Download for just the Best Structure (Last frame usually in BH, or sorted)
1805
+ # Often in BH, the last frame is the accepted state, but not necessarily the global min seen *ever*.
1806
+ # But usually, we want the lowest energy one.
1807
+ energies = [a.get_potential_energy() for a in trajectory]
1808
+ best_idx = np.argmin(energies)
1809
+ best_atoms = trajectory[best_idx]
1810
+
1811
+ # Create XYZ for single best
1812
+ with tempfile.NamedTemporaryFile(mode='w', suffix=".xyz", delete=False) as tmp_best:
1813
+ write(tmp_best.name, best_atoms)
1814
+ tmp_best_name = tmp_best.name
1815
+
1816
+ with open(tmp_best_name, "r") as f:
1817
+ st.download_button(
1818
+ label=f"Download Best Structure (E={energies[best_idx]:.4f} eV)",
1819
+ data=f.read(),
1820
+ file_name="best_global_structure.xyz",
1821
+ mime="chemical/x-xyz"
1822
+ )
1823
+ os.unlink(tmp_best_name)
1824
+
1825
+ # Call the fragment function
1826
+ show_global_trajectory_and_dl()
1827
+
1828
+ # Cleanup main trajectory file after loading it into session state if desired,
1829
+ # though keeping it until session end is safer for re-reads.
1830
+ # os.unlink(traj_filename)
1831
  elif task == "Vibrational Mode Analysis":
1832
  # Conversion factors
1833
  from ase.units import kB as kB_eVK, _Nav, J # ASE's constants
 
1848
  vib.run()
1849
  freqs = vib.get_frequencies()
1850
  energies = vib.get_energies()
 
1851
 
 
1852
  print('\n\n\n\n\n\n\n\n')
1853
  # vib.get_hessian_2d()
1854
  # st.write(vib.summary())
1855
  # print('\n')
1856
  # vib.tabulate()
 
1857
 
 
1858
  freqs_cm = freqs
1859
  freqs_eV = energies
 
1860
 
1861
  # Classify frequencies
1862
  mode_data = []
 
2022
  - For **FairChem models**, isolated atom energies are based on pre-tabulated reference values (provided in a YAML-like structure within the app). Ensure the selected FairChem task type (`omol`, `omat`, etc. for UMA models) or model type (ESEN models use `omol` references) aligns with the system and has the necessary elemental references.
2023
  """)
2024
 
2025
+
2026
  with st.expander('🔧 Tech Stack & System Information'):
2027
  import platform
2028
  import psutil