annabossler committed on
Commit
9223603
·
verified ·
1 Parent(s): bdf6036

Update simulation_scripts_orbmol.py

Browse files
Files changed (1) hide show
  1. simulation_scripts_orbmol.py +77 -12
simulation_scripts_orbmol.py CHANGED
@@ -1,3 +1,4 @@
 
1
  """
2
  Minimal FAIRChem-like simulation helpers for OrbMol (local inference).
3
 
@@ -8,6 +9,7 @@ from simulation_scripts_orbmol import (
8
  run_md_simulation,
9
  run_relaxation_simulation,
10
  atoms_to_xyz,
 
11
  )
12
  """
13
 
@@ -15,6 +17,7 @@ from __future__ import annotations
15
  import os
16
  import tempfile
17
  from pathlib import Path
 
18
 
19
  import numpy as np
20
  import ase
@@ -44,8 +47,6 @@ def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> O
44
  """
45
  global _model_calc
46
  if _model_calc is None:
47
- # NOTE: orb_v3_conservative_inf_omat is the conservative Orb family entry point
48
- # used in OrbMol blog; works for molecules (aperiodic).
49
  orbff = pretrained.orb_v3_conservative_inf_omat(
50
  device=device,
51
  precision=precision,
@@ -66,15 +67,63 @@ def atoms_to_xyz(atoms: ase.Atoms) -> str:
66
  lines.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
67
  return "\n".join(lines)
68
 
 
 
 
 
 
 
 
 
69
  def _center_atoms(atoms: ase.Atoms) -> None:
70
  """
71
  Center coordinates for nicer visualization (no effect on energies).
72
  """
73
  atoms.positions -= atoms.get_center_of_mass()
74
- if atoms.cell is not None and atoms.cell.any():
75
  cell_center = atoms.get_cell().sum(axis=0) / 2
76
  atoms.positions += cell_center
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
79
  """
80
  Read & validate an ASE-compatible file; ensures uniform PBC and non-empty.
@@ -105,7 +154,7 @@ def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase
105
  # Molecular Dynamics (MD)
106
  # -----------------------------
107
  def run_md_simulation(
108
- structure_file: str | Path,
109
  num_steps: int,
110
  num_prerelax_steps: int,
111
  md_timestep: float, # fs
@@ -117,14 +166,19 @@ def run_md_simulation(
117
  ) -> tuple[str, str, str, str]:
118
  """
119
  Run short MD using OrbMol.
 
120
  Returns: (traj_path, md_log_text, reproduction_script, explanation)
121
  """
122
  traj_path = None
123
  md_log_path = None
124
  atoms = None
 
 
125
 
126
  try:
127
- atoms = validate_ase_atoms(structure_file)
 
 
128
 
129
  # Attach the calculator
130
  calc = load_orbmol_model()
@@ -140,7 +194,7 @@ def run_md_simulation(
140
 
141
  # Quick pre-relaxation to remove bad contacts
142
  opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path)
143
- if num_prerelax_steps > 0:
144
  opt.run(fmax=0.05, steps=int(num_prerelax_steps))
145
 
146
  # Initialize velocities (double T after relaxation as in UMA demo)
@@ -220,22 +274,24 @@ dyn.run({int(num_steps)})
220
  return traj_path, md_log_text, reproduction_script, explanation
221
 
222
  except Exception as e:
223
- # Bubble up a clean error
224
  raise RuntimeError(f"Error running MD: {e}") from e
225
  finally:
226
  # Detach calculator to free memory
227
  if atoms is not None and getattr(atoms, "calc", None) is not None:
228
  atoms.calc = None
229
- if md_log_path and not os.path.exists(md_log_path):
230
- md_log_path = None
231
- # (No deletion of traj/log here; the UI needs the files.)
 
 
 
232
 
233
 
234
  # -----------------------------
235
  # Geometry optimization
236
  # -----------------------------
237
  def run_relaxation_simulation(
238
- structure_file: str | Path,
239
  num_steps: int,
240
  fmax: float, # eV/Å
241
  total_charge: int,
@@ -245,14 +301,18 @@ def run_relaxation_simulation(
245
  ) -> tuple[str, str, str, str]:
246
  """
247
  Run LBFGS relaxation (with optional cell relaxation).
 
248
  Returns: (traj_path, log_text, reproduction_script, explanation)
249
  """
250
  traj_path = None
251
  opt_log_path = None
252
  atoms = None
 
 
253
 
254
  try:
255
- atoms = validate_ase_atoms(structure_file)
 
256
 
257
  calc = load_orbmol_model()
258
  atoms.info["charge"] = int(total_charge)
@@ -303,3 +363,8 @@ optimizer.run(fmax={float(fmax)}, steps={int(num_steps)})
303
  finally:
304
  if atoms is not None and getattr(atoms, "calc", None) is not None:
305
  atoms.calc = None
 
 
 
 
 
 
1
+ # simulation_scripts_orbmol.py
2
  """
3
  Minimal FAIRChem-like simulation helpers for OrbMol (local inference).
4
 
 
9
  run_md_simulation,
10
  run_relaxation_simulation,
11
  atoms_to_xyz,
12
+ last_frame_xyz_from_traj,
13
  )
14
  """
15
 
 
17
  import os
18
  import tempfile
19
  from pathlib import Path
20
+ from typing import Tuple
21
 
22
  import numpy as np
23
  import ase
 
47
  """
48
  global _model_calc
49
  if _model_calc is None:
 
 
50
  orbff = pretrained.orb_v3_conservative_inf_omat(
51
  device=device,
52
  precision=precision,
 
67
  lines.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
68
  return "\n".join(lines)
69
 
70
def last_frame_xyz_from_traj(traj_path: str | Path) -> str:
    """
    Read the last frame of an ASE .traj file and return it as an XYZ string.

    Args:
        traj_path: Location of the trajectory file. It is converted with
            ``str()`` before being handed to ``Trajectory``.

    Returns:
        The result of ``atoms_to_xyz`` applied to the final frame.

    Any exception raised while opening or indexing the trajectory propagates
    unchanged to the caller.
    """
    traj = Trajectory(str(traj_path))
    try:
        # Index while the reader is still open, then release the underlying
        # file handle no matter what happened.
        final_frame = traj[-1]
    finally:
        traj.close()
    return atoms_to_xyz(final_frame)
77
+
78
def _center_atoms(atoms: ase.Atoms) -> None:
    """
    Translate all positions in place so the structure is centred.

    The centre of mass is first moved to the origin. If the cell contains
    any non-zero entry, the structure is then shifted by half the sum of
    the cell vectors. Every atom receives the same translation.
    """
    mass_center = atoms.get_center_of_mass()
    atoms.positions -= mass_center

    cell = atoms.cell
    if cell is None or not np.array(cell).any():
        # No usable cell: leave the centre of mass at the origin.
        return

    half_cell_sum = atoms.get_cell().sum(axis=0) / 2
    atoms.positions += half_cell_sum
86
 
87
def _string_looks_like_xyz(text: str) -> bool:
    """
    Simple heuristic to detect whether an input is XYZ data given as text.

    Returns True when ``text`` is a ``str`` with at least two non-blank
    lines and the first whitespace-separated token of the first non-blank
    line parses as an integer (the atom count). Nothing else is checked.
    """
    if not isinstance(text, str):
        return False

    non_blank = [row for row in text.strip().splitlines() if row.strip()]
    if len(non_blank) < 2:
        return False

    # First line: number of atoms.
    try:
        int(non_blank[0].split()[0])
    except Exception:
        return False
    else:
        return True
102
+
103
def _materialize_input_to_file(input_or_path: str | Path) -> tuple[str, bool]:
    """
    Resolve a structure input to a file on disk.

    Args:
        input_or_path: One of
            * a dict carrying a ``"path"`` key (Gradio File payload),
            * a ``str`` or ``Path`` naming an existing file,
            * a ``str`` holding XYZ text.

    Returns:
        ``(file_path, is_temp)``. ``is_temp`` is True only when this function
        wrote the XYZ text to a new temporary ``.xyz`` file; the caller is
        then responsible for deleting it.

    Raises:
        ValueError: if the input matches none of the accepted forms.
    """
    # Case: Gradio File dict {'path': ...}; the path is returned as given.
    if isinstance(input_or_path, dict) and "path" in input_or_path:
        return input_or_path["path"], False

    # Case: Path or string pointing at something that exists on disk.
    if isinstance(input_or_path, (str, Path)) and os.path.exists(str(input_or_path)):
        return str(input_or_path), False

    # Case: most likely raw XYZ text; persist it so file-based readers can use it.
    if isinstance(input_or_path, str) and _string_looks_like_xyz(input_or_path):
        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False)
        try:
            # The context manager flushes and closes the handle even on error.
            with tmp:
                tmp.write(input_or_path)
        except Exception:
            # The caller never learns this path, so remove the partial file
            # here instead of leaving it orphaned in the temp directory.
            try:
                os.remove(tmp.name)
            except OSError:
                pass
            raise
        return tmp.name, True

    raise ValueError("Input must be an existing file path or a valid XYZ string.")
126
+
127
  def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
128
  """
129
  Read & validate an ASE-compatible file; ensures uniform PBC and non-empty.
 
154
  # Molecular Dynamics (MD)
155
  # -----------------------------
156
  def run_md_simulation(
157
+ structure_file_or_xyz: str | Path,
158
  num_steps: int,
159
  num_prerelax_steps: int,
160
  md_timestep: float, # fs
 
166
  ) -> tuple[str, str, str, str]:
167
  """
168
  Run short MD using OrbMol.
169
+ Accepts a path or an XYZ string.
170
  Returns: (traj_path, md_log_text, reproduction_script, explanation)
171
  """
172
  traj_path = None
173
  md_log_path = None
174
  atoms = None
175
+ realized_path = None
176
+ is_temp = False
177
 
178
  try:
179
+ # Permitir tanto ruta como string XYZ
180
+ realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
181
+ atoms = validate_ase_atoms(realized_path)
182
 
183
  # Attach the calculator
184
  calc = load_orbmol_model()
 
194
 
195
  # Quick pre-relaxation to remove bad contacts
196
  opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path)
197
+ if int(num_prerelax_steps) > 0:
198
  opt.run(fmax=0.05, steps=int(num_prerelax_steps))
199
 
200
  # Initialize velocities (double T after relaxation as in UMA demo)
 
274
  return traj_path, md_log_text, reproduction_script, explanation
275
 
276
  except Exception as e:
 
277
  raise RuntimeError(f"Error running MD: {e}") from e
278
  finally:
279
  # Detach calculator to free memory
280
  if atoms is not None and getattr(atoms, "calc", None) is not None:
281
  atoms.calc = None
282
+ # Limpieza del .xyz temporal si lo generamos nosotros
283
+ if is_temp and realized_path and os.path.exists(realized_path):
284
+ try:
285
+ os.remove(realized_path)
286
+ except Exception:
287
+ pass
288
 
289
 
290
  # -----------------------------
291
  # Geometry optimization
292
  # -----------------------------
293
  def run_relaxation_simulation(
294
+ structure_file_or_xyz: str | Path,
295
  num_steps: int,
296
  fmax: float, # eV/Å
297
  total_charge: int,
 
301
  ) -> tuple[str, str, str, str]:
302
  """
303
  Run LBFGS relaxation (with optional cell relaxation).
304
+ Accepts a path or an XYZ string.
305
  Returns: (traj_path, log_text, reproduction_script, explanation)
306
  """
307
  traj_path = None
308
  opt_log_path = None
309
  atoms = None
310
+ realized_path = None
311
+ is_temp = False
312
 
313
  try:
314
+ realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
315
+ atoms = validate_ase_atoms(realized_path)
316
 
317
  calc = load_orbmol_model()
318
  atoms.info["charge"] = int(total_charge)
 
363
  finally:
364
  if atoms is not None and getattr(atoms, "calc", None) is not None:
365
  atoms.calc = None
366
+ if is_temp and realized_path and os.path.exists(realized_path):
367
+ try:
368
+ os.remove(realized_path)
369
+ except Exception:
370
+ pass