Spaces:
Running
Running
Added new optimizers
Browse files- src/streamlit_app.py +21 -5
src/streamlit_app.py
CHANGED
|
@@ -4,13 +4,14 @@ import io
|
|
| 4 |
import tempfile
|
| 5 |
import torch
|
| 6 |
# FOR CPU only mode
|
| 7 |
-
|
| 8 |
# Or disable compilation entirely
|
| 9 |
# torch.backends.cudnn.enabled = False
|
| 10 |
import numpy as np
|
| 11 |
from ase import Atoms
|
| 12 |
from ase.io import read, write
|
| 13 |
-
from ase.optimize import BFGS, LBFGS, FIRE
|
|
|
|
| 14 |
from ase.optimize.basin import BasinHopping
|
| 15 |
from ase.optimize.minimahopping import MinimaHopping
|
| 16 |
from ase.units import kB
|
|
@@ -1258,7 +1259,7 @@ if atoms is not None:
|
|
| 1258 |
atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1259 |
|
| 1260 |
# Display confirmation in the main area (optional, helps the user confirm what's loaded)
|
| 1261 |
-
st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
|
| 1262 |
|
| 1263 |
st.sidebar.markdown("## Model Selection")
|
| 1264 |
if mattersim_available:
|
|
@@ -1365,7 +1366,8 @@ if "Optimization" in task:
|
|
| 1365 |
else:
|
| 1366 |
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
|
| 1367 |
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
|
| 1368 |
-
optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
|
|
|
|
| 1369 |
|
| 1370 |
if "Vibration" in task:
|
| 1371 |
st.write("### Thermodynamic Quantities (Molecule Only)")
|
|
@@ -1555,10 +1557,24 @@ if atoms is not None:
|
|
| 1555 |
traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
|
| 1556 |
if optimizer_type == "BFGS":
|
| 1557 |
opt = BFGS(opt_atoms_obj, trajectory=traj_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1558 |
elif optimizer_type == "LBFGS":
|
| 1559 |
opt = LBFGS(opt_atoms_obj, trajectory=traj_filename)
|
| 1560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1561 |
opt = FIRE(opt_atoms_obj, trajectory=traj_filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1562 |
|
| 1563 |
# opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
|
| 1564 |
opt.attach(lambda: streamlit_log(opt), interval=1)
|
|
|
|
| 4 |
import tempfile
|
| 5 |
import torch
|
| 6 |
# FOR CPU only mode
|
| 7 |
+
torch._dynamo.config.suppress_errors = True
|
| 8 |
# Or disable compilation entirely
|
| 9 |
# torch.backends.cudnn.enabled = False
|
| 10 |
import numpy as np
|
| 11 |
from ase import Atoms
|
| 12 |
from ase.io import read, write
|
| 13 |
+
from ase.optimize import BFGS, LBFGS, FIRE, LBFGSLineSearch, BFGSLineSearch, GPMin, MDMin
|
| 14 |
+
from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
|
| 15 |
from ase.optimize.basin import BasinHopping
|
| 16 |
from ase.optimize.minimahopping import MinimaHopping
|
| 17 |
from ase.units import kB
|
|
|
|
| 1259 |
atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
|
| 1260 |
|
| 1261 |
# Display confirmation in the main area (optional, helps the user confirm what's loaded)
|
| 1262 |
+
# st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
|
| 1263 |
|
| 1264 |
st.sidebar.markdown("## Model Selection")
|
| 1265 |
if mattersim_available:
|
|
|
|
| 1366 |
else:
|
| 1367 |
max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
|
| 1368 |
fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
|
| 1369 |
+
# optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
|
| 1370 |
+
optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "BFGSLineSearch", "LBFGS", "LBFGSLineSearch", "FIRE", "GPMin", "MDMin"], index=1)
|
| 1371 |
|
| 1372 |
if "Vibration" in task:
|
| 1373 |
st.write("### Thermodynamic Quantities (Molecule Only)")
|
|
|
|
| 1557 |
traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
|
| 1558 |
if optimizer_type == "BFGS":
|
| 1559 |
opt = BFGS(opt_atoms_obj, trajectory=traj_filename)
|
| 1560 |
+
|
| 1561 |
+
elif optimizer_type == "BFGSLineSearch":
|
| 1562 |
+
opt = BFGSLineSearch(opt_atoms_obj, trajectory=traj_filename)
|
| 1563 |
+
|
| 1564 |
elif optimizer_type == "LBFGS":
|
| 1565 |
opt = LBFGS(opt_atoms_obj, trajectory=traj_filename)
|
| 1566 |
+
|
| 1567 |
+
elif optimizer_type == "LBFGSLineSearch":
|
| 1568 |
+
opt = LBFGSLineSearch(opt_atoms_obj, trajectory=traj_filename)
|
| 1569 |
+
|
| 1570 |
+
elif optimizer_type == "FIRE":
|
| 1571 |
opt = FIRE(opt_atoms_obj, trajectory=traj_filename)
|
| 1572 |
+
|
| 1573 |
+
elif optimizer_type == "GPMin":
|
| 1574 |
+
opt = GPMin(opt_atoms_obj, trajectory=traj_filename)
|
| 1575 |
+
|
| 1576 |
+
elif optimizer_type == "MDMin":
|
| 1577 |
+
opt = MDMin(opt_atoms_obj, trajectory=traj_filename)
|
| 1578 |
|
| 1579 |
# opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
|
| 1580 |
opt.attach(lambda: streamlit_log(opt), interval=1)
|