annabossler commited on
Commit
4ed9de7
·
verified ·
1 Parent(s): 3494b39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -51
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import tempfile
3
  import numpy as np
@@ -5,17 +6,26 @@ import gradio as gr
5
  from ase.io import read
6
  from ase.io.trajectory import Trajectory
7
 
8
- # ==== Intentar visor nativo como en UMA ====
 
 
 
9
  try:
10
  from gradio_molecule3d import Molecule3D
11
- HAVE_MOL3D = True
12
  except Exception:
13
  HAVE_MOL3D = False
14
 
15
-
16
  # ==== Fallback HTML con 3Dmol.js ====
17
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
18
- traj = Trajectory(traj_path)
 
 
 
 
 
 
 
19
  xyz_frames = []
20
  for atoms in traj:
21
  symbols = atoms.get_chemical_symbols()
@@ -25,32 +35,48 @@ def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
25
  parts.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
26
  xyz_frames.append("\n".join(parts))
27
 
 
 
 
28
  html = f"""
29
- <div id="viewer_md" style="width:{width}px; height:{height}px;"></div>
30
  <script src="https://3dmol.org/build/3Dmol-min.js"></script>
31
  <script>
32
  (function() {{
33
- var viewer = $3Dmol.createViewer("viewer_md", {{backgroundColor: 'white'}});
34
- var frames = {xyz_frames!r};
35
- var i = 0;
36
- function show(i) {{
37
- viewer.clear();
38
- viewer.addModel(frames[i], "xyz");
39
- viewer.setStyle({{}}, {{stick: {{}}}});
40
- viewer.zoomTo();
41
- viewer.render();
42
  }}
43
- if(frames.length>0) show(0);
44
- if(frames.length>1) setInterval(function(){{
45
- i=(i+1)%frames.length; show(i);
46
- }}, {int(interval_ms)});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  }})();
48
  </script>
49
  """
50
  return html
51
 
52
-
53
- # ==== OrbMol SPE directo ====
54
  from orb_models.forcefield import pretrained
55
  from orb_models.forcefield.calculator import ORBCalculator
56
 
@@ -64,8 +90,10 @@ def _load_orbmol_calc():
64
  _MODEL_CALC = ORBCalculator(orbff, device="cpu")
65
  return _MODEL_CALC
66
 
67
-
68
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
 
 
 
69
  try:
70
  calc = _load_orbmol_calc()
71
  if not xyz_content or not xyz_content.strip():
@@ -79,30 +107,31 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
79
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
80
  atoms.calc = calc
81
 
82
- energy = atoms.get_potential_energy()
83
- forces = atoms.get_forces()
84
 
85
  lines = [f"Total Energy: {energy:.6f} eV", "", "Atomic Forces:"]
86
- for i, f in enumerate(forces):
87
- lines.append(f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å")
88
  max_force = float(np.max(np.linalg.norm(forces, axis=1)))
89
  lines += ["", f"Max Force: {max_force:.4f} eV/Å"]
90
 
91
- try: os.unlink(xyz_file)
92
- except Exception: pass
 
 
93
 
94
  return "\n".join(lines), "Calculation completed with OrbMol"
95
  except Exception as e:
96
  return f"Error during calculation: {e}", "Error"
97
 
98
-
99
- # ==== Simulaciones (helpers locales) ====
100
  from simulation_scripts_orbmol import (
101
  run_md_simulation,
102
  run_relaxation_simulation,
103
  )
104
 
105
- # Convierte textbox XYZ a fichero temporal si es necesario (para ASE)
106
  def _string_looks_like_xyz(text: str) -> bool:
107
  try:
108
  first = (text or "").strip().splitlines()[0]
@@ -119,7 +148,6 @@ def _to_file_if_xyz(input_or_path: str):
119
  return tf.name, True
120
  return input_or_path, False
121
 
122
-
123
  # Wrappers: devuelven SIEMPRE (status, traj_path, log, script, explain, html_fallback)
124
  def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
125
  tmp_created = False
@@ -130,7 +158,7 @@ def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
130
  traj_path, log_text, script_text, explanation = run_md_simulation(
131
  path_or_str,
132
  int(steps),
133
- 20, # pre-relax
134
  float(timestep_fs),
135
  float(tempK),
136
  "NVT" if ensemble == "NVT" else "NVE",
@@ -139,7 +167,7 @@ def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
139
  )
140
  status = f"MD completed: {int(steps)} steps at {int(tempK)} K ({ensemble})"
141
 
142
- html_value = "" if HAVE_MOL3D else traj_to_html(traj_path)
143
  return (status, traj_path, log_text, script_text, explanation, html_value)
144
 
145
  except Exception as e:
@@ -149,7 +177,6 @@ def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
149
  try: os.remove(path_or_str)
150
  except Exception: pass
151
 
152
-
153
  def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
154
  tmp_created = False
155
  path_or_str = xyz_content
@@ -166,7 +193,7 @@ def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
166
  )
167
  status = f"Relaxation finished (≤ {int(steps)} steps, fmax={float(fmax)} eV/Å)"
168
 
169
- html_value = "" if HAVE_MOL3D else traj_to_html(traj_path)
170
  return (status, traj_path, log_text, script_text, explanation, html_value)
171
 
172
  except Exception as e:
@@ -176,7 +203,6 @@ def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
176
  try: os.remove(path_or_str)
177
  except Exception: pass
178
 
179
-
180
  # ==== Ejemplos ====
181
  examples = [
182
  ["""2
@@ -197,7 +223,6 @@ H -0.3630 -0.5133 0.8887
197
  H -0.3630 -0.5133 -0.8887""", 0, 1],
198
  ]
199
 
200
-
201
  # ==== UI ====
202
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
203
  with gr.Tabs():
@@ -209,8 +234,8 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
209
  gr.Markdown("Energías y fuerzas con **charge** y **spin multiplicity**.")
210
  xyz_input = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here…")
211
  with gr.Row():
212
- charge_input = gr.Slider(0, -10, 10, 1, label="Charge")
213
- spin_input = gr.Slider(1, 1, 11, 1, label="Spin Multiplicity")
214
  run_spe = gr.Button("Run OrbMol Prediction", variant="primary")
215
  with gr.Column(variant="panel", min_width=500):
216
  spe_out = gr.Textbox(label="Energy & Forces", lines=15, interactive=False)
@@ -223,15 +248,15 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
223
  with gr.Tab("Molecular Dynamics"):
224
  with gr.Row():
225
  with gr.Column(scale=2):
226
- xyz_md = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here…")
227
  with gr.Row():
228
- charge_md = gr.Slider(0, -10, 10, 1, label="Charge")
229
- spin_md = gr.Slider(1, 1, 11, 1, label="Spin Multiplicity")
230
  with gr.Row():
231
- steps_md = gr.Slider(100, 10, 2000, 10, label="Steps")
232
- temp_md = gr.Slider(300, 10, 1500, 10, label="Temperature (K)")
233
  with gr.Row():
234
- timestep_md = gr.Slider(1.0, 0.1, 5.0, 0.1, label="Timestep (fs)")
235
  ensemble_md = gr.Radio(["NVE", "NVT"], value="NVE", label="Ensemble")
236
  run_md_btn = gr.Button("Run MD Simulation", variant="primary")
237
 
@@ -258,7 +283,7 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
258
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
259
  md_explain = gr.Markdown()
260
 
261
- # NOTA: no actualizamos md_viewer directamente; se refresca al cambiar md_traj
262
  run_md_btn.click(
263
  md_wrapper,
264
  inputs=[xyz_md, charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
@@ -269,12 +294,12 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
269
  with gr.Tab("Relaxation / Optimization"):
270
  with gr.Row():
271
  with gr.Column(scale=2):
272
- xyz_rlx = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here…")
273
- steps_rlx = gr.Slider(300, 1, 2000, 1, label="Max Steps")
274
- fmax_rlx = gr.Slider(0.05, 0.001, 0.5, 0.001, label="Fmax (eV/Å)")
275
  with gr.Row():
276
- charge_rlx = gr.Slider(0, -10, 10, 1, label="Charge")
277
- spin_rlx = gr.Slider(1, 1, 11, 1, label="Spin")
278
  relax_cell = gr.Checkbox(False, label="Relax Unit Cell")
279
  run_rlx_btn = gr.Button("Run Optimization", variant="primary")
280
 
 
1
+ # app.py
2
  import os
3
  import tempfile
4
  import numpy as np
 
6
  from ase.io import read
7
  from ase.io.trajectory import Trajectory
8
 
9
+ # ----- Kill-switch para desactivar el visor si molesta (export VIS_DISABLED=1) -----
10
+ VIS_DISABLED = os.environ.get("VIS_DISABLED", "0") == "1"
11
+
12
+ # ==== Intentar visor nativo como en UMA (opcional) ====
13
  try:
14
  from gradio_molecule3d import Molecule3D
15
+ HAVE_MOL3D = True and not VIS_DISABLED
16
  except Exception:
17
  HAVE_MOL3D = False
18
 
 
19
  # ==== Fallback HTML con 3Dmol.js ====
20
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
21
+ """
22
+ Render de una trayectoria ASE (.traj) con 3Dmol.js (sin depender de Jupyter).
23
+ """
24
+ try:
25
+ traj = Trajectory(traj_path)
26
+ except Exception as e:
27
+ return f"<div style='color:#b00'>Error leyendo trayectoria: {e}</div>"
28
+
29
  xyz_frames = []
30
  for atoms in traj:
31
  symbols = atoms.get_chemical_symbols()
 
35
  parts.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
36
  xyz_frames.append("\n".join(parts))
37
 
38
+ if not xyz_frames:
39
+ return "<div style='color:#555'>Empty trajectory</div>"
40
+
41
  html = f"""
42
+ <div id="viewer_md" style="width:{width}px; height:{height}px; position:relative;"></div>
43
  <script src="https://3dmol.org/build/3Dmol-min.js"></script>
44
  <script>
45
  (function() {{
46
+ function ready(fn) {{
47
+ if (document.readyState !== 'loading') fn();
48
+ else document.addEventListener('DOMContentLoaded', fn);
 
 
 
 
 
 
49
  }}
50
+ ready(function() {{
51
+ var el = document.getElementById("viewer_md");
52
+ if (!el || typeof $3Dmol === "undefined") {{
53
+ el.innerHTML = '<div style="padding:8px;color:#b00">3Dmol.js no cargó. Reintenta o habilita CORS.</div>';
54
+ return;
55
+ }}
56
+ var viewer = $3Dmol.createViewer(el, {{backgroundColor: 'white'}});
57
+ var frames = {xyz_frames!r};
58
+ var i = 0;
59
+ function show(k) {{
60
+ viewer.clear();
61
+ viewer.addModel(frames[k], "xyz");
62
+ viewer.setStyle({{}}, {{stick: {{}}}});
63
+ viewer.zoomTo();
64
+ viewer.render();
65
+ }}
66
+ show(0);
67
+ if (frames.length > 1) {{
68
+ setInterval(function() {{
69
+ i = (i + 1) % frames.length;
70
+ show(i);
71
+ }}, {int(interval_ms)});
72
+ }}
73
+ }});
74
  }})();
75
  </script>
76
  """
77
  return html
78
 
79
+ # ==== OrbMol SPE directo (tu calculadora NO se toca) ====
 
80
  from orb_models.forcefield import pretrained
81
  from orb_models.forcefield.calculator import ORBCalculator
82
 
 
90
  _MODEL_CALC = ORBCalculator(orbff, device="cpu")
91
  return _MODEL_CALC
92
 
 
93
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
94
+ """
95
+ Single Point Energy + fuerzas. No escribe nada salvo un .xyz temporal.
96
+ """
97
  try:
98
  calc = _load_orbmol_calc()
99
  if not xyz_content or not xyz_content.strip():
 
107
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
108
  atoms.calc = calc
109
 
110
+ energy = atoms.get_potential_energy() # eV
111
+ forces = atoms.get_forces() # eV/Å
112
 
113
  lines = [f"Total Energy: {energy:.6f} eV", "", "Atomic Forces:"]
114
+ for i, fc in enumerate(forces):
115
+ lines.append(f"Atom {i+1}: [{fc[0]:.4f}, {fc[1]:.4f}, {fc[2]:.4f}] eV/Å")
116
  max_force = float(np.max(np.linalg.norm(forces, axis=1)))
117
  lines += ["", f"Max Force: {max_force:.4f} eV/Å"]
118
 
119
+ try:
120
+ os.unlink(xyz_file)
121
+ except Exception:
122
+ pass
123
 
124
  return "\n".join(lines), "Calculation completed with OrbMol"
125
  except Exception as e:
126
  return f"Error during calculation: {e}", "Error"
127
 
128
+ # ==== Simulaciones (helpers locales, ya los tienes en simulation_scripts_orbmol.py) ====
 
129
  from simulation_scripts_orbmol import (
130
  run_md_simulation,
131
  run_relaxation_simulation,
132
  )
133
 
134
+ # --- Utilidad: si el usuario pega XYZ en el textbox, guardamos a .xyz temporal ---
135
  def _string_looks_like_xyz(text: str) -> bool:
136
  try:
137
  first = (text or "").strip().splitlines()[0]
 
148
  return tf.name, True
149
  return input_or_path, False
150
 
 
151
  # Wrappers: devuelven SIEMPRE (status, traj_path, log, script, explain, html_fallback)
152
  def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
153
  tmp_created = False
 
158
  traj_path, log_text, script_text, explanation = run_md_simulation(
159
  path_or_str,
160
  int(steps),
161
+ 20, # pre-relax fija
162
  float(timestep_fs),
163
  float(tempK),
164
  "NVT" if ensemble == "NVT" else "NVE",
 
167
  )
168
  status = f"MD completed: {int(steps)} steps at {int(tempK)} K ({ensemble})"
169
 
170
+ html_value = "" if (VIS_DISABLED or HAVE_MOL3D) else traj_to_html(traj_path)
171
  return (status, traj_path, log_text, script_text, explanation, html_value)
172
 
173
  except Exception as e:
 
177
  try: os.remove(path_or_str)
178
  except Exception: pass
179
 
 
180
  def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
181
  tmp_created = False
182
  path_or_str = xyz_content
 
193
  )
194
  status = f"Relaxation finished (≤ {int(steps)} steps, fmax={float(fmax)} eV/Å)"
195
 
196
+ html_value = "" if (VIS_DISABLED or HAVE_MOL3D) else traj_to_html(traj_path)
197
  return (status, traj_path, log_text, script_text, explanation, html_value)
198
 
199
  except Exception as e:
 
203
  try: os.remove(path_or_str)
204
  except Exception: pass
205
 
 
206
  # ==== Ejemplos ====
207
  examples = [
208
  ["""2
 
223
  H -0.3630 -0.5133 -0.8887""", 0, 1],
224
  ]
225
 
 
226
  # ==== UI ====
227
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
228
  with gr.Tabs():
 
234
  gr.Markdown("Energías y fuerzas con **charge** y **spin multiplicity**.")
235
  xyz_input = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here…")
236
  with gr.Row():
237
+ charge_input = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
238
+ spin_input = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
239
  run_spe = gr.Button("Run OrbMol Prediction", variant="primary")
240
  with gr.Column(variant="panel", min_width=500):
241
  spe_out = gr.Textbox(label="Energy & Forces", lines=15, interactive=False)
 
248
  with gr.Tab("Molecular Dynamics"):
249
  with gr.Row():
250
  with gr.Column(scale=2):
251
+ xyz_md = gr.Textbox(label="XYZ Coordinates or path (.xyz/.traj/.pdb/.cif)", lines=12, placeholder="Paste XYZ or path here…")
252
  with gr.Row():
253
+ charge_md = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
254
+ spin_md = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
255
  with gr.Row():
256
+ steps_md = gr.Slider(minimum=10, maximum=2000, value=100, step=10, label="Steps")
257
+ temp_md = gr.Slider(minimum=10, maximum=1500, value=300, step=10, label="Temperature (K)")
258
  with gr.Row():
259
+ timestep_md = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Timestep (fs)")
260
  ensemble_md = gr.Radio(["NVE", "NVT"], value="NVE", label="Ensemble")
261
  run_md_btn = gr.Button("Run MD Simulation", variant="primary")
262
 
 
283
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
284
  md_explain = gr.Markdown()
285
 
286
+ # NOTA: el visor Molecule3D se refresca cuando cambia md_traj (no hay que conectarlo)
287
  run_md_btn.click(
288
  md_wrapper,
289
  inputs=[xyz_md, charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
 
294
  with gr.Tab("Relaxation / Optimization"):
295
  with gr.Row():
296
  with gr.Column(scale=2):
297
+ xyz_rlx = gr.Textbox(label="XYZ Coordinates or path (.xyz/.traj/.pdb/.cif)", lines=12, placeholder="Paste XYZ or path here…")
298
+ steps_rlx = gr.Slider(minimum=1, maximum=2000, value=300, step=1, label="Max Steps")
299
+ fmax_rlx = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Fmax (eV/Å)")
300
  with gr.Row():
301
+ charge_rlx = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
302
+ spin_rlx = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin")
303
  relax_cell = gr.Checkbox(False, label="Relax Unit Cell")
304
  run_rlx_btn = gr.Button("Run Optimization", variant="primary")
305