ManasSharma07 commited on
Commit
3b09e3e
·
verified ·
1 Parent(s): db88eb3

Added new optimizers

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +21 -5
src/streamlit_app.py CHANGED
@@ -4,13 +4,14 @@ import io
4
  import tempfile
5
  import torch
6
  # FOR CPU only mode
7
- # torch._dynamo.config.suppress_errors = True
8
  # Or disable compilation entirely
9
  # torch.backends.cudnn.enabled = False
10
  import numpy as np
11
  from ase import Atoms
12
  from ase.io import read, write
13
- from ase.optimize import BFGS, LBFGS, FIRE
 
14
  from ase.optimize.basin import BasinHopping
15
  from ase.optimize.minimahopping import MinimaHopping
16
  from ase.units import kB
@@ -1258,7 +1259,7 @@ if atoms is not None:
1258
  atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1259
 
1260
  # Display confirmation in the main area (optional, helps the user confirm what's loaded)
1261
- st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
1262
 
1263
  st.sidebar.markdown("## Model Selection")
1264
  if mattersim_available:
@@ -1365,7 +1366,8 @@ if "Optimization" in task:
1365
  else:
1366
  max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
1367
  fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
1368
- optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
 
1369
 
1370
  if "Vibration" in task:
1371
  st.write("### Thermodynamic Quantities (Molecule Only)")
@@ -1555,10 +1557,24 @@ if atoms is not None:
1555
  traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
1556
  if optimizer_type == "BFGS":
1557
  opt = BFGS(opt_atoms_obj, trajectory=traj_filename)
 
 
 
 
1558
  elif optimizer_type == "LBFGS":
1559
  opt = LBFGS(opt_atoms_obj, trajectory=traj_filename)
1560
- else: # FIRE
 
 
 
 
1561
  opt = FIRE(opt_atoms_obj, trajectory=traj_filename)
 
 
 
 
 
 
1562
 
1563
  # opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
1564
  opt.attach(lambda: streamlit_log(opt), interval=1)
 
4
  import tempfile
5
  import torch
6
  # FOR CPU only mode
7
+ torch._dynamo.config.suppress_errors = True
8
  # Or disable compilation entirely
9
  # torch.backends.cudnn.enabled = False
10
  import numpy as np
11
  from ase import Atoms
12
  from ase.io import read, write
13
+ from ase.optimize import BFGS, LBFGS, FIRE, LBFGSLineSearch, BFGSLineSearch, GPMin, MDMin
14
+ from ase.optimize.sciopt import SciPyFminBFGS, SciPyFminCG
15
  from ase.optimize.basin import BasinHopping
16
  from ase.optimize.minimahopping import MinimaHopping
17
  from ase.units import kB
 
1259
  atoms.info["spin"] = atoms.info.get("spin", 1) # Default spin (usually 2S for ASE, model might want 2S+1)
1260
 
1261
  # Display confirmation in the main area (optional, helps the user confirm what's loaded)
1262
+ # st.markdown(f"**Loaded Structure:** {atoms.get_chemical_formula()} ({len(atoms)} atoms)")
1263
 
1264
  st.sidebar.markdown("## Model Selection")
1265
  if mattersim_available:
 
1366
  else:
1367
  max_steps = st.sidebar.slider("Maximum Steps:", min_value=10, max_value=200, value=50, step=1)
1368
  fmax = st.sidebar.slider("Convergence Threshold (eV/Å):", min_value=0.001, max_value=0.1, value=0.01, step=0.001, format="%.3f")
1369
+ # optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "LBFGS", "FIRE"], index=1)
1370
+ optimizer_type = st.sidebar.selectbox("Optimizer:", ["BFGS", "BFGSLineSearch", "LBFGS", "LBFGSLineSearch", "FIRE", "GPMin", "MDMin"], index=1)
1371
 
1372
  if "Vibration" in task:
1373
  st.write("### Thermodynamic Quantities (Molecule Only)")
 
1557
  traj_filename = tempfile.NamedTemporaryFile(delete=False, suffix=".traj").name
1558
  if optimizer_type == "BFGS":
1559
  opt = BFGS(opt_atoms_obj, trajectory=traj_filename)
1560
+
1561
+ elif optimizer_type == "BFGSLineSearch":
1562
+ opt = BFGSLineSearch(opt_atoms_obj, trajectory=traj_filename)
1563
+
1564
  elif optimizer_type == "LBFGS":
1565
  opt = LBFGS(opt_atoms_obj, trajectory=traj_filename)
1566
+
1567
+ elif optimizer_type == "LBFGSLineSearch":
1568
+ opt = LBFGSLineSearch(opt_atoms_obj, trajectory=traj_filename)
1569
+
1570
+ elif optimizer_type == "FIRE":
1571
  opt = FIRE(opt_atoms_obj, trajectory=traj_filename)
1572
+
1573
+ elif optimizer_type == "GPMin":
1574
+ opt = GPMin(opt_atoms_obj, trajectory=traj_filename)
1575
+
1576
+ elif optimizer_type == "MDMin":
1577
+ opt = MDMin(opt_atoms_obj, trajectory=traj_filename)
1578
 
1579
  # opt.attach(streamlit_log, interval=1) # Removed lambda for simplicity if streamlit_log is defined correctly
1580
  opt.attach(lambda: streamlit_log(opt), interval=1)