isayev committed on
Commit
7185599
·
verified ·
1 Parent(s): 996f171

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +3 -2
  2. app.py +839 -288
  3. requirements.txt +1 -0
README.md CHANGED
@@ -17,5 +17,6 @@ tags:
17
 
18
  # AIMNet2 Interactive Demo
19
 
20
- Fast neural network interatomic potential for molecular property prediction.
21
- Supports SMILES, XYZ, and PDB input formats.
 
 
17
 
18
  # AIMNet2 Interactive Demo
19
 
20
+ Neural network potential for molecular property prediction.
21
+ Supports energy, forces, charges, geometry optimization, and vibrational frequencies.
22
+ 3D visualization with charge coloring.
app.py CHANGED
@@ -1,24 +1,33 @@
1
- """AIMNet2 Gradio Demo Space.
2
 
3
- Interactive neural network potential: energy, forces, charges, 3D visualization.
4
- Atoms colored by predicted partial charge (red=negative, blue=positive).
5
  """
6
 
7
- import io
8
- import os
9
 
 
 
 
 
 
 
 
10
  import numpy as np
 
11
  import torch
 
12
 
13
  torch.set_num_threads(2)
14
 
15
- import gradio as gr
16
-
17
  # ---------------------------------------------------------------------------
18
  # Constants
19
  # ---------------------------------------------------------------------------
20
  MAX_ATOMS = 200
 
21
  MAX_ATOMS_HESSIAN = 50
 
 
22
 
23
  HARTREE_TO_EV = 27.211386024367243
24
  EV_TO_KCAL = 23.06054783
@@ -30,128 +39,77 @@ ELEMENT_SYMBOLS = {
30
  }
31
  SYMBOL_TO_NUM = {v: k for k, v in ELEMENT_SYMBOLS.items()}
32
 
33
- # Jmol-style CPK colors for default view
34
- CPK_COLORS = {
35
- "H": "#FFFFFF", "B": "#FFB5B5", "C": "#909090", "N": "#3050F8",
36
- "O": "#FF0D0D", "F": "#90E050", "Si": "#F0C8A0", "P": "#FF8000",
37
- "S": "#FFFF30", "Cl": "#1FF01F", "As": "#BD80E3", "Se": "#FFA100",
38
- "Br": "#A62929", "Pd": "#006985", "I": "#940094",
39
  }
40
 
 
 
 
 
 
 
 
 
41
  # ---------------------------------------------------------------------------
42
- # Lazy model loader
43
  # ---------------------------------------------------------------------------
44
- CALC = None
45
 
46
 
47
- def get_calc():
48
- global CALC
49
- if CALC is None:
 
50
  from aimnet.calculators import AIMNet2Calculator
51
- from aimnet.calculators.aimnet2ase import AIMNet2ASE
 
52
 
53
- base_calc = AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3", device="cpu")
54
- CALC = AIMNet2ASE(base_calc)
55
- return CALC
56
 
 
 
 
 
57
 
58
- # ---------------------------------------------------------------------------
59
- # 3D Visualization
60
- # ---------------------------------------------------------------------------
61
 
62
- def charge_to_rgb(q, qmin=-0.8, qmax=0.8):
63
- """Map charge to color: red (negative) -> white (neutral) -> blue (positive)."""
64
- t = np.clip((q - qmin) / (qmax - qmin), 0, 1)
65
- if t < 0.5:
66
- # red -> white
67
- s = t * 2
68
- r, g, b = 1.0, s, s
69
- else:
70
- # white -> blue
71
- s = (t - 0.5) * 2
72
- r, g, b = 1.0 - s, 1.0 - s, 1.0
73
- return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
74
-
75
-
76
- def build_3dmol_html(coords, numbers, charges=None, width=500, height=400):
77
- """Generate HTML with embedded 3Dmol.js viewer, atoms colored by charge."""
78
- # Build XYZ string
79
- n = len(numbers)
80
- xyz_lines = [str(n), "AIMNet2 prediction"]
81
- for i in range(n):
82
- sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X")
83
- x, y, z = coords[i]
84
- xyz_lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
85
- xyz_str = "\\n".join(xyz_lines)
86
-
87
- # Build per-atom color assignments
88
- if charges is not None:
89
- qmin = float(np.min(charges))
90
- qmax = float(np.max(charges))
91
- # Symmetric range for better visualization
92
- qlim = max(abs(qmin), abs(qmax), 0.3)
93
- color_js = ""
94
- for i in range(n):
95
- c = charge_to_rgb(float(charges[i]), -qlim, qlim)
96
- color_js += f'viewer.getModel().setAtomStyle({{index:{i}}},{{stick:{{radius:0.15}},sphere:{{scale:0.25,color:"{c}"}}}});'
97
- color_mode_label = "colored by charge (red=&minus;, blue=+)"
98
- else:
99
- color_js = ""
100
- color_mode_label = "CPK colors"
101
-
102
- html = f"""
103
- <div id="viewer-container" style="position:relative;width:{width}px;height:{height}px;margin:0 auto;">
104
- <div id="mol-viewer" style="width:{width}px;height:{height}px;"></div>
105
- <div style="position:absolute;bottom:4px;right:8px;font-size:11px;color:#888;">{color_mode_label}</div>
106
- </div>
107
- <script src="https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js"></script>
108
- <script>
109
- (function() {{
110
- let viewer = $3Dmol.createViewer(document.getElementById("mol-viewer"), {{
111
- backgroundColor: "white"
112
- }});
113
- viewer.addModel("{xyz_str}", "xyz");
114
- viewer.setStyle({{}}, {{stick:{{radius:0.15}}, sphere:{{scale:0.25,colorscheme:"Jmol"}}}});
115
- {color_js}
116
- viewer.zoomTo();
117
- viewer.render();
118
- }})();
119
- </script>
120
- """
121
- return html
122
 
123
 
124
  # ---------------------------------------------------------------------------
125
  # Parsers
126
  # ---------------------------------------------------------------------------
127
 
128
- def parse_smiles(smiles: str):
129
- """Return (coords, numbers, formal_charge)."""
130
  from rdkit import Chem
131
  from rdkit.Chem import AllChem
132
 
133
  mol = Chem.MolFromSmiles(smiles.strip())
134
  if mol is None:
135
- raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
136
  formal_charge = Chem.GetFormalCharge(mol)
137
  mol = Chem.AddHs(mol)
138
- result = AllChem.EmbedMolecule(mol, AllChem.ETKDGv3())
139
- if result == -1:
140
- raise ValueError("Failed to generate 3D coordinates.")
141
  AllChem.MMFFOptimizeMolecule(mol)
142
  conf = mol.GetConformer()
143
- coords = np.array([conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())], dtype=np.float64)
144
- numbers = np.array([atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.int64)
145
  return coords, numbers, formal_charge
146
 
147
 
148
- def parse_xyz(xyz_text: str):
149
- """Return (coords, numbers)."""
150
- lines = [l.strip() for l in xyz_text.strip().splitlines()]
151
- start = 0
152
- if lines and lines[0].isdigit():
153
- start = 2
154
- coords_list, numbers_list = [], []
155
  for line in lines[start:]:
156
  if not line:
157
  continue
@@ -161,17 +119,17 @@ def parse_xyz(xyz_text: str):
161
  sym = parts[0].capitalize()
162
  if sym not in SYMBOL_TO_NUM:
163
  raise ValueError(f"Unknown element: {sym!r}")
164
- numbers_list.append(SYMBOL_TO_NUM[sym])
165
  coords_list.append([float(parts[1]), float(parts[2]), float(parts[3])])
166
  if not coords_list:
167
  raise ValueError("No atoms found in XYZ input.")
168
- return np.array(coords_list, dtype=np.float64), np.array(numbers_list, dtype=np.int64)
169
 
170
 
171
- def parse_pdb(pdb_text: str):
172
- """Return (coords, numbers)."""
173
- coords_list, numbers_list = [], []
174
- for line in pdb_text.splitlines():
175
  if not line.startswith(("ATOM", "HETATM")):
176
  continue
177
  try:
@@ -184,264 +142,857 @@ def parse_pdb(pdb_text: str):
184
  elem = elem.capitalize()
185
  if elem not in SYMBOL_TO_NUM:
186
  raise ValueError(f"Unknown element in PDB: {elem!r}")
187
- numbers_list.append(SYMBOL_TO_NUM[elem])
188
  coords_list.append([x, y, z])
189
  if not coords_list:
190
  raise ValueError("No ATOM/HETATM records found.")
191
- return np.array(coords_list, dtype=np.float64), np.array(numbers_list, dtype=np.int64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
 
194
  # ---------------------------------------------------------------------------
195
- # Main prediction function — returns (markdown, html_viewer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # ---------------------------------------------------------------------------
197
 
198
  def predict(input_text, input_format, charge, compute_forces, compute_hessian):
199
- """Run AIMNet2 and return (results_markdown, 3d_viewer_html)."""
200
  charge = int(charge)
 
201
 
202
- # --- Parse ---
203
- smiles_warning = ""
204
  try:
205
- if input_format == "SMILES":
206
- coords, numbers, fc = parse_smiles(input_text)
207
- if fc != charge:
208
- smiles_warning = (
209
- f"\n> **Warning:** SMILES formal charge ({fc:+d}) != "
210
- f"supplied charge ({charge:+d}). Using supplied charge.\n"
211
- )
212
- elif input_format == "XYZ":
213
- coords, numbers = parse_xyz(input_text)
214
- elif input_format == "PDB":
215
- coords, numbers = parse_pdb(input_text)
216
- else:
217
- return f"**Error:** Unknown format: {input_format}", ""
218
  except Exception as e:
219
- return f"**Parse error:** {e}", ""
220
 
221
- n_atoms = len(numbers)
222
- if n_atoms > MAX_ATOMS:
223
- return f"**Error:** {n_atoms} atoms exceeds limit of {MAX_ATOMS}.", ""
224
- if compute_hessian and n_atoms > MAX_ATOMS_HESSIAN:
225
- return f"**Error:** Hessian limited to {MAX_ATOMS_HESSIAN} atoms ({n_atoms} given).", ""
226
 
227
- unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS.keys()))
 
228
  if unsupported:
229
- return f"**Error:** Unsupported elements: {unsupported}", ""
 
 
 
 
 
 
 
 
 
 
230
 
231
- # --- Calculate ---
232
  try:
233
- calc = get_calc()
234
- calc.set_charge(charge)
235
-
236
  from ase import Atoms
237
  symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
238
  atoms = Atoms(symbols=symbols, positions=coords)
239
- atoms.calc = calc
240
 
241
  atoms.get_potential_energy()
242
- energy_ev = float(atoms.calc.results["energy"])
243
- charges_arr = atoms.calc.results.get("charges")
 
 
 
244
 
245
  forces_arr = None
246
  if compute_forces:
247
  atoms.get_forces()
248
- forces_arr = atoms.calc.results["forces"]
249
 
250
  hessian_arr = None
 
 
251
  if compute_hessian:
252
  data = {"coord": coords, "numbers": numbers, "charge": float(charge)}
253
- hess_results = calc.base_calc(data, hessian=True)
254
- hessian_arr = hess_results["hessian"].detach().cpu().numpy()
 
255
 
256
  except Exception as e:
257
  import traceback
258
- return f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", ""
259
 
260
- # --- 3D Viewer (colored by charge) ---
261
- viewer_html = build_3dmol_html(coords, numbers, charges_arr)
262
 
263
- # --- Format results ---
264
  energy_kcal = energy_ev * EV_TO_KCAL
265
  energy_ha = energy_ev / HARTREE_TO_EV
 
 
 
 
 
 
 
 
 
 
266
 
267
- lines = []
268
- lines.append("## AIMNet2 Results\n")
269
- if smiles_warning:
270
- lines.append(smiles_warning)
271
- lines.append(f"**Atoms:** {n_atoms} | **Charge:** {charge:+d}\n")
272
-
273
- # Energy table
274
- lines.append("### Energy\n")
275
- lines.append("| Unit | Value |")
276
- lines.append("|------|------:|")
277
- lines.append(f"| eV | {energy_ev:.6f} |")
278
- lines.append(f"| kcal/mol | {energy_kcal:.4f} |")
279
- lines.append(f"| Hartree | {energy_ha:.8f} |")
280
- lines.append("")
281
-
282
- # Charges table
283
  if charges_arr is not None:
284
- lines.append("### Partial Charges (e)\n")
285
- lines.append("| # | Elem | Charge |")
286
- lines.append("|--:|:----:|-------:|")
287
  for i, (z, q) in enumerate(zip(numbers, charges_arr)):
288
- sym = ELEMENT_SYMBOLS.get(int(z), f"Z{z}")
289
- lines.append(f"| {i+1} | {sym} | {q:+.4f} |")
290
- lines.append("")
291
- lines.append(f"*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")
292
 
293
- # Forces
294
  if forces_arr is not None:
295
  max_f = float(np.max(np.linalg.norm(forces_arr, axis=1)))
296
- rms_f = float(np.sqrt(np.mean(forces_arr ** 2)))
297
- lines.append("### Forces (eV/A)\n")
298
- lines.append("| Metric | Value |")
299
- lines.append("|--------|------:|")
300
- lines.append(f"| Max |F| | {max_f:.6f} |")
301
- lines.append(f"| RMS | {rms_f:.6f} |")
302
- lines.append("")
303
  if input_format == "SMILES":
304
- lines.append("> *Geometry from MMFF, not AIMNet2-optimized. Non-zero forces expected.*\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- # Hessian
307
- if hessian_arr is not None:
308
- # Compute vibrational frequencies from Hessian
309
- lines.append("### Vibrational Analysis\n")
310
- try:
311
- freqs = _compute_frequencies(hessian_arr, numbers)
312
- real_freqs = freqs[freqs > 0]
313
- imag_freqs = freqs[freqs < 0]
314
- if len(real_freqs) > 0:
315
- lines.append(f"**Real frequencies:** {len(real_freqs)}\n")
316
- lines.append("```")
317
- for j, f in enumerate(real_freqs):
318
- lines.append(f" {j+1:3d}: {f:10.2f} cm-1")
319
- lines.append("```\n")
320
- if len(imag_freqs) > 0:
321
- lines.append(f"**Imaginary frequencies:** {len(imag_freqs)}\n")
322
- lines.append("```")
323
- for j, f in enumerate(imag_freqs):
324
- lines.append(f" {j+1:3d}: {f:10.2f}i cm-1")
325
- lines.append("```\n")
326
- except Exception as e:
327
- lines.append(f"Frequency analysis failed: {e}\n")
328
- lines.append(f"Hessian shape: `{hessian_arr.shape}`, norm: `{float(np.linalg.norm(hessian_arr)):.4f}`\n")
329
 
330
- lines.append("---")
331
- lines.append("*AIMNet2 wB97M-D3 | [Model card](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) | [Paper](https://doi.org/10.1039/D4SC08572H)*")
 
332
 
333
- return "\n".join(lines), viewer_html
 
 
 
 
 
 
334
 
 
 
 
335
 
336
- def _compute_frequencies(hessian, numbers):
337
- """Compute vibrational frequencies from Hessian matrix.
 
 
 
 
 
 
 
338
 
339
- Returns frequencies in cm^-1 (negative = imaginary).
340
- """
341
- # Atomic masses in amu
342
- MASSES = {
343
- 1: 1.008, 5: 10.81, 6: 12.011, 7: 14.007, 8: 15.999, 9: 18.998,
344
- 14: 28.085, 15: 30.974, 16: 32.06, 17: 35.45, 33: 74.922,
345
- 34: 78.971, 35: 79.904, 46: 106.42, 53: 126.904,
346
- }
347
  n = len(numbers)
 
 
 
 
 
 
348
 
349
- # Reshape Hessian from (N,3,N,3) to (3N, 3N)
350
- if hessian.ndim == 4:
351
- h = hessian.reshape(3 * n, 3 * n)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  else:
353
- h = hessian
354
-
355
- # Mass-weight the Hessian
356
- masses = np.array([MASSES.get(int(z), 1.0) for z in numbers])
357
- mass_vec = np.repeat(masses, 3)
358
- sqrt_mass = np.sqrt(mass_vec)
359
- h_mw = h / np.outer(sqrt_mass, sqrt_mass)
360
-
361
- # Symmetrize
362
- h_mw = 0.5 * (h_mw + h_mw.T)
363
-
364
- # Diagonalize
365
- eigenvalues = np.linalg.eigvalsh(h_mw)
366
-
367
- # Convert: eigenvalue (eV/A^2/amu) -> frequency (cm^-1)
368
- # 1 eV = 1.602176634e-19 J, 1 A = 1e-10 m, 1 amu = 1.66053906660e-27 kg
369
- # freq = sqrt(eigenvalue * eV_to_J / (A_to_m^2 * amu_to_kg)) / (2*pi*c)
370
- eV_to_J = 1.602176634e-19
371
- amu_to_kg = 1.66053906660e-27
372
- A_to_m = 1e-10
373
- c_cm = 2.99792458e10 # speed of light in cm/s
374
-
375
- conv = eV_to_J / (A_to_m**2 * amu_to_kg)
376
- freqs = []
377
- for ev in eigenvalues:
378
- if ev >= 0:
379
- f = np.sqrt(ev * conv) / (2 * np.pi * c_cm)
380
- else:
381
- f = -np.sqrt(-ev * conv) / (2 * np.pi * c_cm)
382
- freqs.append(f)
383
-
384
- freqs = np.array(freqs)
385
- # Remove 6 (or 5 for linear) near-zero translational/rotational modes
386
- freqs = freqs[np.abs(freqs) > 10.0]
387
- return np.sort(freqs)
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
 
390
  # ---------------------------------------------------------------------------
391
  # Gradio UI
392
  # ---------------------------------------------------------------------------
393
 
394
- EXAMPLES = [
395
- ["CCO", "SMILES", 0, True, False],
396
- ["c1ccccc1", "SMILES", 0, True, False],
397
- ["CC(=O)O", "SMILES", 0, True, False],
398
- ["[NH4+]", "SMILES", 1, True, False],
399
- ["CC(=O)[O-]", "SMILES", -1, True, False],
400
- ["O=C(O)c1ccccc1","SMILES", 0, True, False],
401
- ["C1CCCC1", "SMILES", 0, True, True],
 
 
 
 
 
402
  ]
403
 
 
 
 
 
 
 
404
  with gr.Blocks(title="AIMNet2 Demo", theme=gr.themes.Soft()) as demo:
405
  gr.Markdown(
406
  "# AIMNet2 Interactive Demo\n"
407
- "Neural network potential for molecular property prediction: "
408
- "**energy, forces, charges, frequencies**. \n"
409
- "Atoms are colored by predicted partial charge "
410
- "(red = negative, blue = positive)."
411
  )
412
 
 
413
  with gr.Row():
414
  with gr.Column(scale=1):
415
  input_format = gr.Radio(
416
- choices=["SMILES", "XYZ", "PDB"],
417
- value="SMILES",
418
- label="Input Format",
419
  )
420
- input_text = gr.Textbox(
421
- lines=6,
422
- label="Molecule",
423
- placeholder="Enter SMILES (e.g. CCO), XYZ block, or PDB block...",
 
424
  )
425
- charge = gr.Number(value=0, precision=0, label="Charge")
426
- compute_forces = gr.Checkbox(value=True, label="Compute Forces")
427
- compute_hessian = gr.Checkbox(value=False, label="Compute Hessian & Frequencies")
428
- run_btn = gr.Button("Run AIMNet2", variant="primary")
429
 
430
- with gr.Column(scale=1):
431
- viewer = gr.HTML(label="3D Structure", value="<div style='height:400px;display:flex;align-items:center;justify-content:center;color:#aaa;'>Run a calculation to see the 3D structure</div>")
432
- output = gr.Markdown(label="Results")
 
 
 
 
 
 
 
 
433
 
434
- gr.Examples(
435
- examples=EXAMPLES,
436
- inputs=[input_text, input_format, charge, compute_forces, compute_hessian],
437
- label="Example Molecules",
438
  )
439
 
440
- run_btn.click(
441
- fn=predict,
442
- inputs=[input_text, input_format, charge, compute_forces, compute_hessian],
443
- outputs=[output, viewer],
444
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
446
  if __name__ == "__main__":
447
  demo.launch()
 
1
+ """AIMNet2 Interactive Demo v2.
2
 
3
+ 3D visualization, geometry optimization, vibrational analysis, charge coloring.
4
+ https://huggingface.co/spaces/isayevlab/aimnet2-demo
5
  """
6
 
7
+ from __future__ import annotations
 
8
 
9
+ import html
10
+ import json
11
+ import tempfile
12
+ import time
13
+ from pathlib import Path
14
+
15
+ import gradio as gr
16
  import numpy as np
17
+ import plotly.graph_objects as go
18
  import torch
19
+ from plotly.subplots import make_subplots
20
 
21
  torch.set_num_threads(2)
22
 
 
 
23
  # ---------------------------------------------------------------------------
24
  # Constants
25
  # ---------------------------------------------------------------------------
26
  MAX_ATOMS = 200
27
+ MAX_ATOMS_OPT = 50
28
  MAX_ATOMS_HESSIAN = 50
29
+ REQUEST_TIMEOUT = 90 # seconds cumulative per request
30
+ OPT_TIMEOUT = 85 # leave margin for Hessian
31
 
32
  HARTREE_TO_EV = 27.211386024367243
33
  EV_TO_KCAL = 23.06054783
 
39
  }
40
  SYMBOL_TO_NUM = {v: k for k, v in ELEMENT_SYMBOLS.items()}
41
 
42
# Atomic masses in amu (IUPAC 2021), keyed by atomic number.
ATOMIC_MASSES = {
    1: 1.008,      # H
    5: 10.81,      # B
    6: 12.011,     # C
    7: 14.007,     # N
    8: 15.999,     # O
    9: 18.998,     # F
    14: 28.085,    # Si
    15: 30.974,    # P
    16: 32.06,     # S
    17: 35.45,     # Cl
    33: 74.922,    # As
    34: 78.971,    # Se
    35: 79.904,    # Br
    46: 106.42,    # Pd
    53: 126.904,   # I
}

# Unit conversion for mass-weighted Hessian eigenvalues:
# eV / (A^2 * amu) -> s^-2. The individual factors are kept as named
# constants so the derivation of _FREQ_CONV stays readable.
_EV_TO_J = 1.602176634e-19       # joules per electronvolt
_AMU_TO_KG = 1.66053906660e-27   # kilograms per atomic mass unit
_A_TO_M = 1e-10                  # metres per angstrom
_C_CM = 2.99792458e10            # speed of light in cm/s
_FREQ_CONV = _EV_TO_J / (_A_TO_M**2 * _AMU_TO_KG)  # eV/(A^2*amu) -> s^-2
55
+
56
+
57
  # ---------------------------------------------------------------------------
58
+ # Model loader (eager, singleton)
59
  # ---------------------------------------------------------------------------
60
+ BASE_CALC = None
61
 
62
 
63
def get_base_calc():
    """Return the shared AIMNet2Calculator, constructing it on first use.

    The instance is cached in the module-level ``BASE_CALC`` so the model is
    only built once per process; later calls return the cached object.

    NOTE(review): the calculator is shared between requests. Nothing in this
    function serialises concurrent first calls or guards the calculator
    itself -- confirm that shared read-only use is safe for aimnet.
    """
    global BASE_CALC
    if BASE_CALC is not None:
        return BASE_CALC

    # Imported lazily so that importing this module does not require aimnet
    # until a calculator is actually requested.
    from aimnet.calculators import AIMNet2Calculator

    BASE_CALC = AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3", device="cpu")
    return BASE_CALC
70
 
 
 
 
71
 
72
def make_ase_calc(charge: int = 0):
    """Build a new AIMNet2ASE wrapper around the shared base calculator.

    A separate wrapper is created on every call, so the ``charge`` passed
    here belongs to that wrapper only; the underlying model comes from
    ``get_base_calc()`` and is reused.
    """
    from aimnet.calculators.aimnet2ase import AIMNet2ASE

    shared_model = get_base_calc()
    return AIMNet2ASE(shared_model, charge=charge)
76
 
 
 
 
77
 
78
# Eager-load model at import time (during Space startup, not first request).
# This is a best-effort warm-up: a failure must not prevent the app from
# starting, but it is logged (with traceback) instead of being swallowed so
# the cause is visible in the startup logs. BASE_CALC stays None on failure,
# so get_base_calc() will try again on the first request.
try:
    get_base_calc()
except Exception:
    import logging

    logging.getLogger(__name__).warning(
        "AIMNet2 model preload failed; loading will be retried on first request.",
        exc_info=True,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  # ---------------------------------------------------------------------------
86
  # Parsers
87
  # ---------------------------------------------------------------------------
88
 
89
def parse_smiles(smiles: str) -> tuple[np.ndarray, np.ndarray, int]:
    """Parse a SMILES string into a 3D structure.

    Hydrogens are added, a conformer is embedded with ETKDGv3 and then
    relaxed with MMFF.

    Parameters
    ----------
    smiles : str
        SMILES string; surrounding whitespace is ignored.

    Returns
    -------
    coords : ndarray of float64, shape (N, 3)
        Atomic positions of the embedded conformer.
    numbers : ndarray of int64, shape (N,)
        Atomic numbers, hydrogens included.
    formal_charge : int
        Total formal charge reported by RDKit for the parsed molecule.

    Raises
    ------
    ValueError
        If the input is blank, RDKit cannot parse it, or 3D embedding fails.
    """
    text = (smiles or "").strip()
    # Checked before touching RDKit: a blank SMILES parses to an atomless
    # molecule rather than None, which would otherwise surface later as a
    # misleading embedding failure.
    if not text:
        raise ValueError("Empty SMILES string.")

    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(text)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    formal_charge = Chem.GetFormalCharge(mol)
    mol = Chem.AddHs(mol)
    if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
        raise ValueError("Failed to generate 3D coordinates. Try a different molecule.")

    # Pre-relaxation is best-effort: the status code is deliberately ignored
    # and the embedded geometry is used as-is if MMFF cannot be applied.
    AllChem.MMFFOptimizeMolecule(mol)

    conf = mol.GetConformer()
    # Explicit dtypes keep the arrays consistent regardless of platform
    # defaults.
    coords = np.asarray(conf.GetPositions(), dtype=np.float64)
    numbers = np.array([a.GetAtomicNum() for a in mol.GetAtoms()], dtype=np.int64)
    return coords, numbers, formal_charge
106
 
107
 
108
+ def parse_xyz(text: str) -> tuple[np.ndarray, np.ndarray]:
109
+ """Parse XYZ format text -> (coords, numbers)."""
110
+ lines = [l.strip() for l in text.strip().splitlines()]
111
+ start = 2 if lines and lines[0].isdigit() else 0
112
+ coords_list, nums_list = [], []
 
 
113
  for line in lines[start:]:
114
  if not line:
115
  continue
 
119
  sym = parts[0].capitalize()
120
  if sym not in SYMBOL_TO_NUM:
121
  raise ValueError(f"Unknown element: {sym!r}")
122
+ nums_list.append(SYMBOL_TO_NUM[sym])
123
  coords_list.append([float(parts[1]), float(parts[2]), float(parts[3])])
124
  if not coords_list:
125
  raise ValueError("No atoms found in XYZ input.")
126
+ return np.array(coords_list), np.array(nums_list)
127
 
128
 
129
+ def parse_pdb(text: str) -> tuple[np.ndarray, np.ndarray]:
130
+ """Parse PDB format text -> (coords, numbers)."""
131
+ coords_list, nums_list = [], []
132
+ for line in text.splitlines():
133
  if not line.startswith(("ATOM", "HETATM")):
134
  continue
135
  try:
 
142
  elem = elem.capitalize()
143
  if elem not in SYMBOL_TO_NUM:
144
  raise ValueError(f"Unknown element in PDB: {elem!r}")
145
+ nums_list.append(SYMBOL_TO_NUM[elem])
146
  coords_list.append([x, y, z])
147
  if not coords_list:
148
  raise ValueError("No ATOM/HETATM records found.")
149
+ return np.array(coords_list), np.array(nums_list)
150
+
151
+
152
def parse_input(
    text: str,
    fmt: str,
    charge: int | None = None,
) -> tuple[np.ndarray, np.ndarray, str]:
    """Parse molecule input in one of the supported formats.

    Parameters
    ----------
    text : str
        Molecule description (SMILES string, XYZ block or PDB block).
    fmt : str
        One of ``"SMILES"``, ``"XYZ"`` or ``"PDB"``.
    charge : int, optional
        Total charge supplied by the caller. When given and ``fmt`` is
        ``"SMILES"``, it is compared with the formal charge of the parsed
        molecule and a warning is returned on mismatch. With the default
        ``None`` no comparison is made and the warning is always empty.

    Returns
    -------
    coords : ndarray
        Atomic positions returned by the format-specific parser.
    numbers : ndarray
        Atomic numbers returned by the format-specific parser.
    warning : str
        Markdown warning text, or ``""`` when there is nothing to report.

    Raises
    ------
    ValueError
        If ``fmt`` is not a supported format, or the parser rejects ``text``.
    """
    if fmt == "SMILES":
        coords, numbers, formal_charge = parse_smiles(text)
        warning = ""
        if charge is not None and formal_charge != int(charge):
            warning = (
                f"\n> **Warning:** SMILES formal charge ({formal_charge:+d}) "
                f"differs from supplied charge ({int(charge):+d}).\n"
            )
        return coords, numbers, warning
    if fmt == "XYZ":
        coords, numbers = parse_xyz(text)
        return coords, numbers, ""
    if fmt == "PDB":
        coords, numbers = parse_pdb(text)
        return coords, numbers, ""
    raise ValueError(f"Unknown format: {fmt}")
164
+
165
+
166
def handle_file_upload(file_obj) -> tuple[str, str]:
    """Turn an uploaded file into ``(text_content, format_name)``.

    The result is meant to populate the text input and the format radio.

    Parameters
    ----------
    file_obj
        ``None``, a path (``str`` or ``Path``), or an object exposing the
        file path as ``.name``.

    Returns
    -------
    text : str
        File content for ``.xyz``/``.pdb``; an XYZ block generated from the
        first record for ``.sdf``/``.mol``; ``""`` when nothing was uploaded.
    fmt : str
        ``"XYZ"`` or ``"PDB"``; ``"SMILES"`` when nothing was uploaded.

    Raises
    ------
    ValueError
        For unsupported extensions, unreadable SDF/MOL files, or when 3D
        coordinates cannot be generated for a record without a conformer.
    """
    if file_obj is None:
        return "", "SMILES"

    # A Path has a ``.name`` attribute too, but it is only the basename, so
    # real paths must be handled before falling back to ``.name``.
    if isinstance(file_obj, (str, Path)):
        path = Path(file_obj)
    else:
        path = Path(getattr(file_obj, "name", file_obj))
    suffix = path.suffix.lower()

    # Decide from the extension first: unsupported (possibly binary) uploads
    # are rejected without being decoded, and SDF/MOL files are read by RDKit
    # only.
    text_formats = {".xyz": "XYZ", ".pdb": "PDB"}
    if suffix in text_formats:
        content = path.read_text(encoding="utf-8", errors="replace")
        return content, text_formats[suffix]
    if suffix not in (".sdf", ".mol"):
        raise ValueError(f"Unsupported file type: {suffix}")

    from rdkit import Chem
    from rdkit.Chem import AllChem

    suppl = Chem.SDMolSupplier(str(path), removeHs=False)
    mol = next(suppl, None)  # only the first record is used
    if mol is None:
        raise ValueError("Could not read SDF file.")
    if mol.GetNumConformers() == 0:
        mol = Chem.AddHs(mol)
        if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
            raise ValueError("Failed to generate 3D coordinates from SDF file.")

    # Convert to XYZ text so the record goes through the regular XYZ parser.
    conf = mol.GetConformer()
    n = mol.GetNumAtoms()
    xyz_lines = [str(n), f"Converted from {path.name}"]
    for atom in mol.GetAtoms():
        pos = conf.GetAtomPosition(atom.GetIdx())
        xyz_lines.append(f"{atom.GetSymbol()} {pos.x:.6f} {pos.y:.6f} {pos.z:.6f}")
    return "\n".join(xyz_lines), "XYZ"
202
+
203
+
204
+ # ---------------------------------------------------------------------------
205
+ # 3D Viewer (iframe + 3Dmol.js)
206
+ # ---------------------------------------------------------------------------
207
+
208
def _charge_to_hex(q: float, qlim: float) -> str:
    """Map a charge to a hex color: red (negative) -> white (0) -> blue (positive).

    Parameters
    ----------
    q : float
        Charge to colour. Values beyond ``[-qlim, +qlim]`` saturate.
    qlim : float
        Half-width of the colour scale; ``-qlim`` is pure red and ``+qlim``
        is pure blue.

    Returns
    -------
    str
        Colour as ``"#rrggbb"``. A NaN charge or an unusable scale
        (``qlim`` non-positive or non-finite) yields neutral white instead
        of raising.
    """
    # A zero-width scale would divide by zero, and NaN cannot be converted
    # to an integer channel value; treat both as "no information".
    if np.isnan(q) or not np.isfinite(qlim) or qlim <= 0:
        return "#ffffff"

    t = float(np.clip((q + qlim) / (2 * qlim), 0, 1))
    if t < 0.5:
        # red -> white: green and blue ramp up together
        s = t * 2
        rgb = (1.0, s, s)
    else:
        # white -> blue: red and green ramp down together
        s = (t - 0.5) * 2
        rgb = (1.0 - s, 1.0 - s, 1.0)
    return "#" + "".join(f"{int(channel * 255):02x}" for channel in rgb)
218
+
219
+
220
def build_viewer_html(
    coords: np.ndarray,
    numbers: np.ndarray,
    charges: np.ndarray | None = None,
    height: int = 420,
) -> str:
    """Build iframe HTML with a 3Dmol.js viewer and a CPK/charge colour toggle.

    Parameters
    ----------
    coords : ndarray, shape (N, 3)
        Atomic positions written into the embedded XYZ block.
    numbers : ndarray, shape (N,)
        Atomic numbers; values missing from ``ELEMENT_SYMBOLS`` are written
        as ``"X"``.
    charges : ndarray, optional
        Per-atom charges. When given for 1..100 atoms, a button is added
        that switches between element colours and charge colours.
    height : int
        Viewer height in pixels; the iframe is 30 px taller.

    Returns
    -------
    str
        An ``<iframe srcdoc=...>`` element containing the whole viewer page.
    """
    n = len(numbers)

    def to_js(value) -> str:
        # JSON is valid JavaScript; "</" is additionally escaped so that no
        # payload can terminate the surrounding <script> element.
        return json.dumps(value).replace("</", "<\\/")

    # Build XYZ string
    xyz_lines = [str(n), "AIMNet2"]
    for i, z in enumerate(numbers):
        x, y, zc = coords[i]
        sym = ELEMENT_SYMBOLS.get(int(z), "X")
        xyz_lines.append(f"{sym} {x:.6f} {y:.6f} {zc:.6f}")
    xyz_string = "\n".join(xyz_lines)

    # Charge colouring is offered only when charges exist and the molecule
    # is small (<= 100 atoms). n == 0 is excluded because the colour scale
    # is derived from the largest absolute charge.
    has_toggle = charges is not None and 0 < n <= 100
    toggle_btn = ""
    toggle_js = ""
    if has_toggle:
        qlim = max(float(np.max(np.abs(charges))), 0.3)
        colors = [_charge_to_hex(float(charges[i]), qlim) for i in range(n)]
        toggle_btn = (
            '<button id="toggle-btn" onclick="toggleColors()" '
            'style="position:absolute;top:8px;right:8px;z-index:10;'
            'padding:4px 10px;font-size:12px;cursor:pointer;'
            'border:1px solid #ccc;border-radius:4px;background:#f8f8f8;">'
            'Color by charge</button>'
        )
        # Colours are shipped as one array and applied in a loop through the
        # viewer's setStyle with an {index: i} selection.
        # NOTE(review): the previous version called
        # viewer.getModel().setAtomStyle(...); that method does not appear
        # in the 3Dmol.js API reference -- confirm against the pinned 2.4.2
        # build.
        toggle_js = (
            "var chargeColors = " + to_js(colors) + ";\n"
            "var cpkMode = true;\n"
            "function setCPK() {\n"
            "  viewer.setStyle({}, baseStyle);\n"
            "  viewer.render();\n"
            "}\n"
            "function setCharges() {\n"
            "  for (var i = 0; i < chargeColors.length; i++) {\n"
            "    viewer.setStyle({index: i}, {stick: {radius: 0.15},\n"
            "      sphere: {scale: 0.25, color: chargeColors[i]}});\n"
            "  }\n"
            "  viewer.render();\n"
            "}\n"
            "function toggleColors() {\n"
            "  if (!viewer) { return; }\n"
            "  cpkMode = !cpkMode;\n"
            '  var btn = document.getElementById("toggle-btn");\n'
            "  if (cpkMode) {\n"
            "    setCPK();\n"
            '    btn.textContent = "Color by charge";\n'
            "  } else {\n"
            "    setCharges();\n"
            '    btn.textContent = "CPK colors";\n'
            "  }\n"
            "}\n"
        )

    # The toggle functions are declared at script top level (not inside the
    # try block) so the button's inline onclick handler can always resolve
    # them. If the viewer cannot be created, the fallback message is shown
    # and the button is hidden.
    inner_html = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8">
<style>
body {{ margin:0; overflow:hidden; font-family:sans-serif; }}
#viewer {{ width:100%; height:{height}px; position:relative; }}
#fallback {{ display:none; padding:20px; color:#888; text-align:center; }}
</style>
</head><body>
<div id="viewer"></div>
{toggle_btn}
<div id="fallback">3D viewer unavailable. Results are shown below.</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js"></script>
<script>
var viewer = null;
var baseStyle = {{stick:{{radius:0.15}}, sphere:{{scale:0.25, colorscheme:"Jmol"}}}};
{toggle_js}
try {{
var xyz = {to_js(xyz_string)};
viewer = $3Dmol.createViewer("viewer", {{backgroundColor:"white"}});
viewer.addModel(xyz, "xyz");
viewer.setStyle({{}}, baseStyle);
viewer.zoomTo();
viewer.render();
}} catch(e) {{
viewer = null;
document.getElementById("viewer").style.display = "none";
document.getElementById("fallback").style.display = "block";
var toggleButton = document.getElementById("toggle-btn");
if (toggleButton) {{ toggleButton.style.display = "none"; }}
}}
</script>
</body></html>"""

    # The page is embedded through srcdoc, so it must be attribute-escaped.
    escaped = html.escape(inner_html, quote=True)
    return (
        f'<iframe srcdoc="{escaped}" width="100%" height="{height + 30}" '
        f'frameborder="0" sandbox="allow-scripts" '
        f'style="border:1px solid #eee;border-radius:8px;"></iframe>'
    )
318
 
319
 
320
  # ---------------------------------------------------------------------------
321
+ # Frequency computation
322
+ # ---------------------------------------------------------------------------
323
+
324
def is_linear(coords: np.ndarray, numbers: np.ndarray, tol: float = 1e-3) -> bool:
    """Check if a molecule is linear via its moment of inertia tensor.

    Parameters
    ----------
    coords : ndarray, shape (N, 3)
        Atomic positions.
    numbers : ndarray, shape (N,)
        Atomic numbers; elements missing from ``ATOMIC_MASSES`` are given a
        mass of 1.0.
    tol : float
        Threshold on the ratio of the smallest to the largest principal
        moment below which the geometry counts as linear.

    Returns
    -------
    bool
        True when the smallest principal moment is negligible relative to
        the largest. A single atom (all moments zero) is reported as linear.
    """
    masses = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    centered = coords - np.average(coords, weights=masses, axis=0)

    # Inertia tensor sum_i m_i (|r_i|^2 * E - r_i r_i^T), evaluated for all
    # atoms at once instead of accumulating atom by atom.
    weighted = masses[:, None] * centered
    inertia = np.sum(weighted * centered) * np.eye(3) - weighted.T @ centered

    moments = np.linalg.eigvalsh(inertia)  # ascending order
    # The floor on the denominator avoids 0/0 when every moment is zero.
    return moments[0] / max(moments[-1], 1e-30) < tol
334
+
335
+
336
def compute_frequencies(
    hessian: np.ndarray,
    numbers: np.ndarray,
    coords: np.ndarray,
) -> tuple[np.ndarray, int]:
    """Compute vibrational frequencies from a Hessian.

    Parameters
    ----------
    hessian : ndarray, shape (N,3,N,3) or (3N,3N)
        Hessian in eV/A^2.
    numbers : ndarray, shape (N,)
        Atomic numbers.
    coords : ndarray, shape (N,3)
        Atomic positions (used only for the linearity check).

    Returns
    -------
    freqs_cm : ndarray
        Vibrational frequencies in cm^-1, sorted ascending.
        Negative = imaginary.
    n_imag : int
        Number of returned frequencies below -10 cm^-1.
    """
    n_atoms = len(numbers)
    flat_hessian = hessian.reshape(3 * n_atoms, 3 * n_atoms)

    # Mass-weight: each Cartesian component carries the mass of its atom;
    # elements missing from ATOMIC_MASSES fall back to 1.0.
    atom_masses = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    dof_masses = np.repeat(atom_masses, 3)
    weighted = flat_hessian / np.sqrt(np.outer(dof_masses, dof_masses))
    weighted = 0.5 * (weighted + weighted.T)  # symmetrize

    eigvals = np.linalg.eigvalsh(weighted)

    # Convert to cm^-1; the sign of the eigenvalue is carried over so that
    # negative curvature shows up as a negative ("imaginary") frequency.
    magnitude = np.sqrt(np.abs(eigvals) * _FREQ_CONV)
    freqs = np.sign(eigvals) * magnitude / (2 * np.pi * _C_CM)

    # Remove translation/rotation modes by count: the 5 (linear) or 6
    # modes with the smallest absolute frequency are discarded.
    n_discard = 5 if is_linear(coords, numbers) else 6
    by_magnitude = np.argsort(np.abs(freqs))
    vibrational = np.sort(freqs[by_magnitude[n_discard:]])

    imaginary_count = int(np.sum(vibrational < -10))
    return vibrational, imaginary_count
385
+
386
+
387
+ # ---------------------------------------------------------------------------
388
+ # Plotting
389
+ # ---------------------------------------------------------------------------
390
+
391
def make_frequency_plot(freqs: np.ndarray) -> go.Figure:
    """Create a Plotly stick spectrum of vibrational frequencies.

    Only strictly positive frequencies are drawn, each as a unit-height bar;
    the figure has no trace when there are none.
    """
    positive = freqs[freqs > 0]

    figure = go.Figure()
    if positive.size:
        sticks = go.Bar(
            x=positive,
            y=np.ones_like(positive),
            width=3,
            marker_color="steelblue",
            hovertemplate="%{x:.1f} cm\u207b\u00b9<extra></extra>",
        )
        figure.add_trace(sticks)

    figure.update_layout(
        title="Vibrational Spectrum",
        xaxis_title="Frequency (cm\u207b\u00b9)",
        yaxis_visible=False,  # bar heights carry no information
        showlegend=False,
        height=200,
        margin=dict(l=40, r=20, t=30, b=40),
    )
    return figure
409
+
410
+
411
def make_convergence_plot(trajectory: list[dict]) -> go.Figure:
    """Plot energy (left axis) and max force (right axis) against step number.

    Each trajectory record must provide the keys ``"step"``, ``"energy"``
    and ``"fmax"``.
    """
    steps = [record["step"] for record in trajectory]

    # (record key, label, line colour, drawn on secondary y-axis?)
    series = (
        ("energy", "Energy (eV)", "steelblue", False),
        ("fmax", "Max |F| (eV/\u00c5)", "firebrick", True),
    )

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    for key, label, colour, on_secondary in series:
        curve = go.Scatter(
            x=steps,
            y=[record[key] for record in trajectory],
            name=label,
            mode="lines+markers",
            marker=dict(size=4),
            line=dict(color=colour),
        )
        fig.add_trace(curve, secondary_y=on_secondary)

    fig.update_xaxes(title_text="Step")
    for _, label, _, on_secondary in series:
        fig.update_yaxes(title_text=label, secondary_y=on_secondary)

    fig.update_layout(
        height=280,
        margin=dict(l=60, r=60, t=30, b=40),
        legend=dict(x=0.5, y=1.15, xanchor="center", orientation="h"),
    )
    return fig
436
+
437
+
438
+ # ---------------------------------------------------------------------------
439
+ # Geometry optimization
440
+ # ---------------------------------------------------------------------------
441
+
442
def run_optimization(
    atoms,
    max_steps: int,
    fmax_target: float,
    timeout: float = OPT_TIMEOUT,
) -> tuple[list[dict], bool, float]:
    """Run an LBFGS geometry optimization with a step and wall-time budget.

    Parameters
    ----------
    atoms
        ASE ``Atoms`` object with a calculator attached; it is modified in
        place.
    max_steps : int
        Maximum number of LBFGS steps to take.
    fmax_target : float
        Convergence threshold on the largest per-atom force norm.
    timeout : float
        Wall-time budget in seconds, checked before every step.

    Returns
    -------
    trajectory : list[dict]
        One record ``{"step", "energy", "fmax"}`` per step taken. Each record
        describes the geometry *after* that step, so the last record matches
        ``atoms.get_positions()``.
    converged : bool
        True when the current geometry satisfies ``fmax < fmax_target``.
    wall_time : float
        Elapsed seconds.
    """
    from ase.optimize import LBFGS

    def _max_force() -> float:
        # Goes through the ASE getter so the value always belongs to the
        # current positions (served from the calculator cache when nothing
        # has moved since the last evaluation).
        return float(np.max(np.linalg.norm(atoms.get_forces(), axis=1)))

    opt = LBFGS(atoms, logfile=None)
    trajectory: list[dict] = []
    # Monotonic clock: elapsed-time measurement must not jump with wall-clock
    # adjustments.
    t0 = time.monotonic()

    # Already converged at the starting geometry: do not move the atoms.
    converged = _max_force() < fmax_target

    for step in range(max_steps):
        if converged or time.monotonic() - t0 > timeout:
            break
        opt.step()
        # opt.step() evaluates forces and THEN displaces the atoms, so the
        # calculator's stored results now describe the previous geometry.
        # Re-query through the getters to record values for the new one; the
        # next opt.step() reuses this evaluation.
        fmax = _max_force()
        energy = float(atoms.get_potential_energy())
        trajectory.append({"step": step + 1, "energy": energy, "fmax": fmax})
        converged = fmax < fmax_target

    return trajectory, converged, time.monotonic() - t0
473
+
474
+
475
+ # ---------------------------------------------------------------------------
476
+ # Reproduction script generator
477
+ # ---------------------------------------------------------------------------
478
+
479
def _fmt_array(arr: np.ndarray, name: str) -> str:
    """Render *arr* as a Python assignment statement to the variable *name*.

    A 1-D array becomes a plain list literal. Anything else is iterated row
    by row and written as an ``np.array([...])`` call with six decimals per
    value.
    """
    if arr.ndim == 1:
        return f"{name} = {arr.tolist()!r}"

    # One "[v, v, v]," line per row.
    body = "\n".join(
        " [" + ", ".join(f"{value:.6f}" for value in row) + "],"
        for row in arr
    )
    return f"{name} = np.array([\n" + body + "\n])"
488
+
489
+
490
def generate_script(
    coords: np.ndarray,
    numbers: np.ndarray,
    charge: int,
    task: str = "single_point",
    max_steps: int = 30,
    fmax: float = 0.05,
    compute_hessian: bool = False,
) -> str:
    """Return the text of a standalone Python script for this calculation.

    The script always starts with the imports, the molecule definition and
    the calculator setup. ``task == "optimize"`` appends an LBFGS run using
    *fmax* and *max_steps*; any other value appends a single-point
    evaluation. ``compute_hessian`` appends a Hessian evaluation at the end.
    """
    preamble = [
        "# AIMNet2 calculation",
        "# Generated by https://huggingface.co/spaces/isayevlab/aimnet2-demo",
        "from aimnet.calculators import AIMNet2Calculator",
        "from aimnet.calculators.aimnet2ase import AIMNet2ASE",
        "from ase import Atoms",
        "import numpy as np",
        "",
        _fmt_array(coords, "coords"),
        f"numbers = {numbers.tolist()!r}",
        f"charge = {charge}",
        "",
        'calc = AIMNet2ASE(AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3"), charge=charge)',
        "atoms = Atoms(numbers=numbers, positions=coords)",
        "atoms.calc = calc",
        "",
    ]

    if task == "optimize":
        task_section = [
            "from ase.optimize import LBFGS",
            "opt = LBFGS(atoms, logfile='-')",
            f"opt.run(fmax={fmax}, steps={max_steps})",
            "",
            "energy = atoms.get_potential_energy()",
            'print(f"Optimized energy: {energy:.6f} eV")',
            'print(f"Max force: {max(np.linalg.norm(atoms.get_forces(), axis=1)):.6f} eV/A")',
        ]
    else:
        task_section = [
            "energy = atoms.get_potential_energy()",
            "forces = atoms.get_forces()",
            'charges = atoms.calc.results["charges"]',
            'print(f"Energy: {energy:.6f} eV")',
        ]

    hessian_section = []
    if compute_hessian:
        hessian_section = [
            "",
            "# Hessian & frequencies",
            "base_calc = calc.base_calc",
            'hess_result = base_calc({"coord": atoms.get_positions(), '
            '"numbers": atoms.numbers, "charge": float(charge)}, hessian=True)',
            'hessian = hess_result["hessian"].detach().cpu().numpy()',
            "# Diagonalize mass-weighted Hessian for frequencies (see demo source for details)",
        ]

    return "\n".join([*preamble, *task_section, *hessian_section])
548
+
549
+
550
+ # ---------------------------------------------------------------------------
551
+ # XYZ download helper
552
+ # ---------------------------------------------------------------------------
553
+
554
def write_xyz_file(coords: np.ndarray, numbers: np.ndarray,
                   charges: np.ndarray | None = None,
                   comment: str = "AIMNet2") -> str:
    """Write the structure to a temporary ``.xyz`` file and return its path.

    Parameters
    ----------
    coords : ndarray, shape (N,3)
        Atomic positions, one ``x, y, z`` row per atom.
    numbers : ndarray, shape (N,)
        Atomic numbers; a number missing from ``ELEMENT_SYMBOLS`` is written
        as ``"X"``.
    charges : ndarray or None
        Optional per-atom values appended as a fifth column.
    comment : str
        Text for the second line of the file. Line breaks are replaced by
        spaces so that the comment cannot spill into the atom records.

    Returns
    -------
    str
        Path of the file. It is created with ``delete=False``, so it remains
        on disk after this function returns; removing it is left to the caller.
    """
    n = len(numbers)
    # The comment must occupy exactly one line of the XYZ layout.
    comment_line = " ".join(str(comment).splitlines())
    lines = [str(n), comment_line]
    for i in range(n):
        sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X")
        x, y, z = coords[i]
        q_str = f" {charges[i]:+.4f}" if charges is not None else ""
        lines.append(f"{sym:2s} {x:12.6f} {y:12.6f} {z:12.6f}{q_str}")

    # Context manager closes the handle even if the write raises; explicit
    # encoding keeps the output independent of the platform default.
    with tempfile.NamedTemporaryFile(
        suffix=".xyz", delete=False, mode="w", encoding="utf-8"
    ) as tmp:
        tmp.write("\n".join(lines) + "\n")
    return tmp.name
569
+
570
+
571
+ # ---------------------------------------------------------------------------
572
+ # Tab 1: Single-point calculation
573
  # ---------------------------------------------------------------------------
574
 
575
def predict(input_text, input_format, charge, compute_forces, compute_hessian):
    """Run a single-point calculation and assemble the outputs for the UI.

    Returns a 5-tuple ``(markdown, viewer_html, freq_plot, xyz_file, script)``.
    On any failure the first element is a Markdown error message and the
    remaining four are empty placeholders (``""``, ``None``, ``None``, ``""``).
    """
    empty = ("", "", None, None, "")

    # A missing or non-numeric charge (e.g. None) would otherwise raise out of
    # the handler; report it the same way as every other input problem.
    try:
        charge = int(charge)
    except (TypeError, ValueError):
        return ("**Error:** Charge must be an integer.", *empty[1:])

    # Parse
    try:
        # NOTE(review): the third value returned by parse_input (a warning) is
        # not shown to the user -- confirm whether it should be surfaced.
        coords, numbers, _ = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])

    n = len(numbers)
    if n > MAX_ATOMS:
        return (f"**Error:** {n} atoms exceeds limit of {MAX_ATOMS}.", *empty[1:])
    if compute_hessian and n > MAX_ATOMS_HESSIAN:
        return (f"**Error:** Hessian limited to {MAX_ATOMS_HESSIAN} atoms ({n} given).", *empty[1:])

    # Validate elements
    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # SMILES charge validation: warn when the formal charge reported by
    # parse_smiles differs from the user-supplied one. The supplied charge wins.
    smiles_warn = ""
    if input_format == "SMILES":
        _, _, fc = parse_smiles(input_text)
        if fc != charge:
            smiles_warn = (
                f"\n> **Warning:** SMILES formal charge ({fc:+d}) != "
                f"supplied charge ({charge:+d}). Using supplied charge.\n"
            )

    # Calculate
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc

        atoms.get_potential_energy()
        energy_ev = float(ase_calc.results["energy"])
        charges_arr = ase_calc.results.get("charges")

        if not np.isfinite(energy_ev):
            return ("**Error:** Model produced NaN/Inf. Molecule may be outside training domain.", *empty[1:])

        forces_arr = None
        if compute_forces:
            atoms.get_forces()
            forces_arr = ase_calc.results["forces"]

        freqs = None
        n_imag = 0
        if compute_hessian:
            data = {"coord": coords, "numbers": numbers, "charge": float(charge)}
            hess_result = get_base_calc()(data, hessian=True)
            hessian_arr = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian_arr, numbers, coords)

    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(coords, numbers, charges_arr)

    # Results markdown
    energy_kcal = energy_ev * EV_TO_KCAL
    energy_ha = energy_ev / HARTREE_TO_EV
    md = []
    md.append("## AIMNet2 Results\n")
    if smiles_warn:
        md.append(smiles_warn)
    md.append(f"**Atoms:** {n} | **Charge:** {charge:+d}\n")

    md.append("### Energy\n| Unit | Value |\n|------|------:|")
    md.append(f"| eV | {energy_ev:.6f} |")
    md.append(f"| kcal/mol | {energy_kcal:.4f} |")
    md.append(f"| Hartree | {energy_ha:.8f} |\n")

    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")

    max_f = None
    if forces_arr is not None:
        max_f = float(np.max(np.linalg.norm(forces_arr, axis=1)))
        rms_f = float(np.sqrt(np.mean(forces_arr**2)))
        md.append("### Forces (eV/A)\n| Metric | Value |\n|--------|------:|")
        # The pipes in |F| are escaped: a bare "|" inside a table cell starts
        # a new cell and pushes the value out of its column.
        md.append(f"| Max \\|F\\| | {max_f:.6f} |")
        md.append(f"| RMS | {rms_f:.6f} |")

    if input_format == "SMILES":
        md.append("\n> *Geometry from MMFF, not AIMNet2-optimized.*\n")

    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        # Without forces the geometry cannot be shown to be stationary, so the
        # caveat is printed unless forces were computed and are small.
        if max_f is None or max_f > 0.05:
            md.append("> *Frequencies at non-stationary point. Low modes may be unreliable.*\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)

    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) | [Paper](https://doi.org/10.1039/D4SC08572H)*")

    xyz_file = write_xyz_file(coords, numbers, charges_arr,
                              comment=f"Energy: {energy_ev:.6f} eV")

    script = generate_script(coords, numbers, charge, "single_point",
                             compute_hessian=compute_hessian)

    return "\n".join(md), viewer_html, freq_plot, xyz_file, script
704
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
 
706
+ # ---------------------------------------------------------------------------
707
+ # Tab 2: Geometry optimization
708
+ # ---------------------------------------------------------------------------
709
 
710
def optimize(input_text, input_format, charge, max_steps, fmax_target,
             compute_freqs):
    """Run a geometry optimization and assemble the outputs for the UI.

    Returns a 6-tuple
    ``(markdown, viewer_html, conv_plot, freq_plot, xyz_file, script)``.
    On any failure the first element is a Markdown error message and the
    remaining five are empty placeholders.

    When ``compute_freqs`` is set, ``fmax_target`` is capped at 0.02 before
    it is validated and used.
    """
    empty = ("", "", None, None, None, "")

    # Missing or non-numeric inputs (e.g. None) would otherwise raise out of
    # the handler; report them the same way as every other input problem.
    try:
        charge = int(charge)
        max_steps = int(max_steps)
        fmax_target = float(fmax_target)
    except (TypeError, ValueError):
        return ("**Error:** Charge, max steps and fmax must be numeric.", *empty[1:])

    # Auto-tighten fmax when frequencies requested
    if compute_freqs and fmax_target > 0.02:
        fmax_target = 0.02

    # Validate fmax
    if not 0.01 <= fmax_target <= 1.0:
        return ("**Error:** fmax must be between 0.01 and 1.0 eV/A.", *empty[1:])

    # Parse
    try:
        coords, numbers, _ = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])

    n = len(numbers)
    if n > MAX_ATOMS_OPT:
        return (f"**Error:** Optimization limited to {MAX_ATOMS_OPT} atoms ({n} given).", *empty[1:])

    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # Optimize
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc

        # Initial energy/forces. Forces are requested explicitly instead of
        # assuming the energy call also stored them in the results dict.
        atoms.get_potential_energy()
        e0 = float(ase_calc.results["energy"])
        atoms.get_forces()
        f0 = ase_calc.results["forces"]
        fmax0 = float(np.max(np.linalg.norm(f0, axis=1)))

        trajectory, _, wall_time = run_optimization(
            atoms, max_steps, fmax_target
        )

        opt_coords = atoms.get_positions()

        # Evaluate at the returned geometry so that the reported energy,
        # forces, charges and convergence status all describe opt_coords.
        # An optimizer step displaces the atoms after its force evaluation,
        # so values stored before the last step may be one geometry behind.
        # The getters reuse the calculator's results when nothing has moved.
        atoms.get_potential_energy()
        atoms.get_forces()
        e_final = float(ase_calc.results["energy"])
        fmax_final = float(
            np.max(np.linalg.norm(ase_calc.results["forces"], axis=1))
        )
        charges_arr = ase_calc.results.get("charges")
        converged = fmax_final < fmax_target

        if not np.isfinite(e_final):
            return ("**Error:** Model produced NaN/Inf during optimization.", *empty[1:])

        # Frequencies at optimized geometry
        freqs = None
        n_imag = 0
        if compute_freqs:
            data = {
                "coord": opt_coords,
                "numbers": atoms.numbers,
                "charge": float(charge),
            }
            hess_result = get_base_calc()(data, hessian=True)
            hessian = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian, atoms.numbers, opt_coords)

    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(opt_coords, numbers, charges_arr)
    conv_plot = make_convergence_plot(trajectory) if trajectory else None

    md = []
    md.append("## Optimization Results\n")
    status = "Converged" if converged else "Not converged"
    if not converged:
        md.append(f"> **{status}** after {len(trajectory)} steps / {wall_time:.1f}s "
                  f"(final fmax: {fmax_final:.4f} eV/A)\n")
    else:
        md.append(f"**{status}** in {len(trajectory)} steps ({wall_time:.1f}s)\n")

    md.append("| Property | Initial | Final |")
    md.append("|----------|--------:|------:|")
    md.append(f"| Energy (eV) | {e0:.6f} | {e_final:.6f} |")
    md.append(f"| Energy (kcal/mol) | {e0*EV_TO_KCAL:.4f} | {e_final*EV_TO_KCAL:.4f} |")
    # The pipes in |F| are escaped: a bare "|" inside a table cell starts a
    # new cell and pushes the values out of their columns.
    md.append(f"| Max \\|F\\| (eV/A) | {fmax0:.6f} | {fmax_final:.6f} |")
    md.append(f"| dE (eV) | | {e_final - e0:.6f} |")
    md.append("")

    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")

    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)

    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3)*")

    xyz_file = write_xyz_file(opt_coords, numbers, charges_arr,
                              comment=f"Optimized, E={e_final:.6f} eV, fmax={fmax_final:.6f}")

    # The script starts from the input geometry so that running it repeats
    # the optimization rather than starting at the optimized structure.
    script = generate_script(coords, numbers, charge, "optimize",
                             max_steps=max_steps, fmax=fmax_target,
                             compute_hessian=compute_freqs)

    return "\n".join(md), viewer_html, conv_plot, freq_plot, xyz_file, script
841
 
842
 
843
  # ---------------------------------------------------------------------------
844
  # Gradio UI
845
  # ---------------------------------------------------------------------------
846
 
847
# Example rows for the "Calculate" tab. Column order follows the inputs wired
# to its gr.Examples: molecule text, input format, charge, compute forces,
# compute Hessian & frequencies.
CALC_EXAMPLES = [
    ["CCO", "SMILES", 0, True, False],
    ["c1ccccc1", "SMILES", 0, True, False],
    ["CC(=O)O", "SMILES", 0, True, False],
    ["[NH4+]", "SMILES", 1, True, False],
    ["CC(=O)[O-]", "SMILES", -1, True, False],
    ["O=C(O)c1ccccc1", "SMILES", 0, True, False],
    ["O", "SMILES", 0, True, True],
]

# Example rows for the "Optimize" tab. Column order: molecule text, input
# format, charge, max steps, convergence fmax, compute frequencies.
OPT_EXAMPLES = [
    ["CCO", "SMILES", 0, 30, 0.05, False],
    ["O", "SMILES", 0, 30, 0.05, True],
]

# HTML shown in the viewer panels before a calculation has produced a
# structure, and whenever a handler returns an empty viewer string.
VIEWER_PLACEHOLDER = (
    '<div style="height:420px;display:flex;align-items:center;'
    'justify-content:center;color:#aaa;border:1px solid #eee;'
    'border-radius:8px;">Run a calculation to see the 3D structure</div>'
)
867
+
868
# Top-level Gradio application: a shared molecule-input region followed by two
# tabs ("Calculate" and "Optimize") that both read from the shared inputs.
with gr.Blocks(title="AIMNet2 Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# AIMNet2 Interactive Demo\n"
        "Neural network potential: **energy, forces, charges, optimization, frequencies**. \n"
        "3D viewer with charge coloring (red = negative, blue = positive)."
    )

    # --- Shared input region ---
    with gr.Row():
        with gr.Column(scale=1):
            input_format = gr.Radio(
                ["SMILES", "XYZ", "PDB"], value="SMILES", label="Input Format"
            )
            input_text = gr.Textbox(lines=6, label="Molecule",
                                    placeholder="SMILES, XYZ block, or PDB...")
            file_upload = gr.File(
                label="Or upload file",
                file_types=[".xyz", ".pdb", ".sdf", ".mol"],
            )
            charge_input = gr.Number(value=0, precision=0, label="Charge")

    # File upload handler
    def on_file_upload(file_obj):
        """Copy an uploaded file's text and detected format into the inputs.

        Returns a pair of updates for (input_text, input_format). Both are
        no-op updates when nothing was uploaded or when reading the file fails.
        """
        if file_obj is None:
            return gr.update(), gr.update()
        try:
            text, fmt = handle_file_upload(file_obj)
            gr.Info(f"Loaded file ({fmt} format)")
            return gr.update(value=text), gr.update(value=fmt)
        except Exception as e:
            # Best effort: report the failure and leave the current inputs as
            # they are.
            gr.Warning(f"File upload failed: {e}")
            return gr.update(), gr.update()

    file_upload.change(
        on_file_upload, inputs=[file_upload], outputs=[input_text, input_format]
    )

    # --- Tabs ---
    with gr.Tabs():
        # ===== Tab 1: Calculate =====
        with gr.TabItem("Calculate"):
            with gr.Row():
                with gr.Column(scale=1):
                    calc_forces = gr.Checkbox(value=True, label="Compute Forces")
                    calc_hessian = gr.Checkbox(value=False,
                                               label="Compute Hessian & Frequencies")
                    calc_btn = gr.Button("Calculate", variant="primary")

                    gr.Examples(
                        examples=CALC_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                calc_forces, calc_hessian],
                        label="Examples",
                    )

                with gr.Column(scale=2):
                    calc_viewer = gr.HTML(value=VIEWER_PLACEHOLDER, label="3D Structure")
                    calc_results = gr.Markdown(label="Results")
                    # Hidden until a calculation actually returns a spectrum.
                    calc_freq_plot = gr.Plot(label="Frequency Spectrum", visible=False)
                    with gr.Accordion("Download XYZ", open=False):
                        calc_xyz = gr.File(label="XYZ file", interactive=False)
                    with gr.Accordion("Python code to reproduce", open=False):
                        calc_script = gr.Code(language="python", label="Script")

            def calc_wrapper(text, fmt, charge, forces, hessian):
                """Adapt predict() output to the Calculate tab's components.

                Replaces an empty viewer string with the placeholder and shows
                the frequency plot only when predict() produced one.
                """
                md, viewer, fplot, xyz, script = predict(text, fmt, charge, forces, hessian)
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )

            calc_btn.click(
                calc_wrapper,
                inputs=[input_text, input_format, charge_input, calc_forces, calc_hessian],
                outputs=[calc_results, calc_viewer, calc_freq_plot, calc_xyz, calc_script],
            )

        # ===== Tab 2: Optimize =====
        with gr.TabItem("Optimize"):
            with gr.Row():
                with gr.Column(scale=1):
                    opt_steps = gr.Slider(10, 50, value=30, step=1, label="Max Steps")
                    opt_fmax = gr.Number(value=0.05, label="Convergence fmax (eV/A)",
                                         minimum=0.01, maximum=1.0)
                    opt_freqs = gr.Checkbox(value=False,
                                            label="Compute frequencies at minimum")
                    opt_btn = gr.Button("Optimize", variant="primary")

                    gr.Examples(
                        examples=OPT_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                opt_steps, opt_fmax, opt_freqs],
                        label="Examples",
                    )

                with gr.Column(scale=2):
                    opt_viewer = gr.HTML(value=VIEWER_PLACEHOLDER, label="Optimized Structure")
                    opt_conv_plot = gr.Plot(label="Convergence")
                    opt_results = gr.Markdown(label="Results")
                    # Hidden until an optimization actually returns a spectrum.
                    opt_freq_plot = gr.Plot(label="Frequency Spectrum", visible=False)
                    with gr.Accordion("Download optimized XYZ", open=False):
                        opt_xyz = gr.File(label="XYZ file", interactive=False)
                    with gr.Accordion("Python code to reproduce", open=False):
                        opt_script = gr.Code(language="python", label="Script")

            def opt_wrapper(text, fmt, charge, steps, fmax, freqs):
                """Adapt optimize() output to the Optimize tab's components.

                Replaces an empty viewer string with the placeholder and shows
                the frequency plot only when optimize() produced one.
                """
                md, viewer, conv, fplot, xyz, script = optimize(
                    text, fmt, charge, steps, fmax, freqs
                )
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    conv,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )

            opt_btn.click(
                opt_wrapper,
                inputs=[input_text, input_format, charge_input,
                        opt_steps, opt_fmax, opt_freqs],
                outputs=[opt_results, opt_viewer, opt_conv_plot,
                         opt_freq_plot, opt_xyz, opt_script],
            )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt CHANGED
@@ -6,3 +6,4 @@ ase==3.27.0
6
  rdkit
7
  numpy
8
  gradio>=5.0
 
 
6
  rdkit
7
  numpy
8
  gradio>=5.0
9
+ plotly>=5.0