# aimnet2-demo / app.py
# Uploaded by isayev with huggingface_hub (commit 7ae2e87, verified)
"""AIMNet2 Interactive Demo v2.
3D visualization, geometry optimization, vibrational analysis, charge coloring.
https://huggingface.co/spaces/isayevlab/aimnet2-demo
"""
from __future__ import annotations
import html
import json
import os
import tempfile
import time
from pathlib import Path
# Disable torch.compile/dynamo — HF Spaces CPU has no compiler toolchain,
# and dynamo hangs during Hessian autograd backward passes.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
torch.set_num_threads(2)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Size limits checked by the request handlers before any model call.
MAX_ATOMS = 200  # single-point calculations
MAX_ATOMS_OPT = 50  # geometry optimizations
MAX_ATOMS_HESSIAN = 50  # Hessian / frequency calculations
# NOTE(review): REQUEST_TIMEOUT is not referenced in the visible code; only
# OPT_TIMEOUT is used (as the default timeout of run_optimization).
REQUEST_TIMEOUT = 90 # seconds cumulative per request
OPT_TIMEOUT = 85 # leave margin for Hessian
# Energy conversions: eV per Hartree, and kcal/mol per eV.
HARTREE_TO_EV = 27.211386024367243
EV_TO_KCAL = 23.06054783
# Atomic number -> element symbol. Inputs containing any other element are
# rejected by the parsers/handlers (presumably this is the set of elements
# the model supports -- confirm against the model card).
ELEMENT_SYMBOLS = {
    1: "H", 5: "B", 6: "C", 7: "N", 8: "O", 9: "F",
    14: "Si", 15: "P", 16: "S", 17: "Cl",
    33: "As", 34: "Se", 35: "Br", 46: "Pd", 53: "I",
}
# Reverse lookup used by the XYZ/PDB parsers: element symbol -> atomic number.
SYMBOL_TO_NUM = {v: k for k, v in ELEMENT_SYMBOLS.items()}
# Atomic masses in amu (IUPAC 2021)
ATOMIC_MASSES = {
    1: 1.008, 5: 10.81, 6: 12.011, 7: 14.007, 8: 15.999, 9: 18.998,
    14: 28.085, 15: 30.974, 16: 32.06, 17: 35.45, 33: 74.922,
    34: 78.971, 35: 79.904, 46: 106.42, 53: 126.904,
}
# Unit conversion: eigenvalue (eV/A^2/amu) -> s^-2
_EV_TO_J = 1.602176634e-19
_AMU_TO_KG = 1.66053906660e-27
_A_TO_M = 1e-10
_C_CM = 2.99792458e10 # speed of light in cm/s
# Multiplying a mass-weighted Hessian eigenvalue by this factor gives an
# angular frequency squared in SI units; compute_frequencies then divides the
# square root by 2*pi*c to obtain wavenumbers.
_FREQ_CONV = _EV_TO_J / (_A_TO_M**2 * _AMU_TO_KG) # eV/(A^2*amu) -> s^-2
# ---------------------------------------------------------------------------
# Model loader (eager, singleton)
# ---------------------------------------------------------------------------
BASE_CALC = None


def get_base_calc():
    """Return the process-wide AIMNet2Calculator, creating it on first use.

    The instance is cached in the module-level ``BASE_CALC`` so the model is
    loaded at most once per process; later calls return the cached object.
    The calculator is created for CPU execution.
    """
    global BASE_CALC
    if BASE_CALC is not None:
        return BASE_CALC

    # Imported lazily so that importing this module does not require aimnet.
    from aimnet.calculators import AIMNet2Calculator

    BASE_CALC = AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3", device="cpu")
    return BASE_CALC
def make_ase_calc(charge: int = 0):
    """Return a new AIMNet2ASE wrapper around the shared base calculator.

    A separate wrapper is built on every call (one per request) while the
    underlying model from ``get_base_calc()`` is reused.

    Parameters
    ----------
    charge : int
        Total molecular charge passed to the wrapper.
    """
    from aimnet.calculators.aimnet2ase import AIMNet2ASE

    shared_model = get_base_calc()
    wrapper = AIMNet2ASE(shared_model, charge=charge)
    return wrapper
# Eager-load model at import time (during Space startup, not first request).
# This is deliberately best-effort: a failure here must not prevent the UI
# from starting, because the first request will retry and surface the error.
# The failure is logged (with traceback) instead of being silently dropped so
# that startup problems are visible in the Space logs.
try:
    get_base_calc()
except Exception:
    import logging

    logging.getLogger(__name__).warning(
        "AIMNet2 model preload failed; loading will be retried on first request.",
        exc_info=True,
    )
# ---------------------------------------------------------------------------
# Parsers
# ---------------------------------------------------------------------------
def parse_smiles(smiles: str) -> tuple[np.ndarray, np.ndarray, int]:
    """Parse SMILES -> (coords, numbers, formal_charge).

    The molecule is built with RDKit: explicit hydrogens are added, a 3D
    conformer is embedded with ETKDGv3 and then relaxed with MMFF.

    Parameters
    ----------
    smiles : str
        SMILES string; surrounding whitespace is ignored.

    Returns
    -------
    coords : ndarray
        Conformer atom positions, one row per atom (hydrogens included).
    numbers : ndarray
        Atomic numbers in the same atom order.
    formal_charge : int
        Net formal charge of the parsed molecule.

    Raises
    ------
    ValueError
        If the input is empty, is not valid SMILES, or no 3D conformer could
        be generated.
    """
    cleaned = (smiles or "").strip()
    # Reject empty input up front: RDKit accepts "" as a valid zero-atom
    # molecule, which would otherwise only fail later with a misleading
    # message from the embedding step.
    if not cleaned:
        raise ValueError("Empty SMILES string.")

    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(cleaned)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    formal_charge = Chem.GetFormalCharge(mol)
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 when no conformer could be generated.
    if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
        raise ValueError("Failed to generate 3D coordinates. Try a different molecule.")
    AllChem.MMFFOptimizeMolecule(mol)
    conf = mol.GetConformer()
    coords = np.array([conf.GetAtomPosition(i) for i in range(mol.GetNumAtoms())])
    numbers = np.array([a.GetAtomicNum() for a in mol.GetAtoms()])
    return coords, numbers, formal_charge
def parse_xyz(text: str) -> tuple[np.ndarray, np.ndarray]:
    """Parse XYZ format text -> (coords, numbers).

    Two layouts are accepted:

    * Standard XYZ: the first line is the atom count and the second line is a
      comment. Exactly that many atom records are read, so only the first
      frame of a multi-frame file is used.
    * Bare coordinates: no count/comment header; every line with at least
      four whitespace-separated fields is read as ``symbol x y z``.

    Lines with fewer than four fields are skipped. Extra columns after the
    coordinates are ignored. Element symbols are case-insensitive.

    Raises
    ------
    ValueError
        For an unknown element, non-numeric coordinates, no atoms at all, or
        fewer atom records than the header declares.
    """
    lines = [ln.strip() for ln in text.strip().splitlines()]

    expected = None
    start = 0
    if lines and lines[0].isdigit():
        expected = int(lines[0])
        start = 2  # skip the count line and the comment line

    coords_list, nums_list = [], []
    for line in lines[start:]:
        # Stop after the declared number of atoms so that following frames of
        # a trajectory are not merged into one overlapping structure.
        if expected is not None and len(coords_list) >= expected:
            break
        parts = line.split()
        if len(parts) < 4:
            continue
        sym = parts[0].capitalize()
        if sym not in SYMBOL_TO_NUM:
            raise ValueError(f"Unknown element: {sym!r}")
        try:
            xyz = [float(parts[1]), float(parts[2]), float(parts[3])]
        except ValueError as err:
            raise ValueError(f"Invalid coordinates in XYZ line: {line!r}") from err
        nums_list.append(SYMBOL_TO_NUM[sym])
        coords_list.append(xyz)

    if not coords_list:
        raise ValueError("No atoms found in XYZ input.")
    # A short frame usually means a truncated paste or a missing comment line
    # (which makes the first atom be consumed as the comment).
    if expected is not None and len(coords_list) < expected:
        raise ValueError(
            f"XYZ header declares {expected} atoms but only "
            f"{len(coords_list)} atom lines were found."
        )
    return np.array(coords_list), np.array(nums_list)
def _pdb_element_from_name(atom_name: str) -> str:
    """Guess the element symbol from the PDB atom-name field (columns 13-16).

    Used only when the element columns (77-78) are empty. The original
    heuristic (whole name, leading digits removed) is tried first so every
    name that was accepted before still maps to the same element. If that
    fails, the PDB column convention is applied: a blank or digit in column
    13 means a one-letter element in column 14 (" CA " is carbon, "1HB " is
    hydrogen), and a four-character name starting with "H" is hydrogen.
    Returns the (possibly unsupported) symbol; the caller validates it.
    """
    legacy = atom_name.strip().lstrip("0123456789").capitalize()
    if legacy in SYMBOL_TO_NUM:
        return legacy
    padded = atom_name.ljust(4)
    if not padded[0].isalpha():
        candidate = padded[1].upper()
        if candidate in SYMBOL_TO_NUM:
            return candidate
    elif (
        padded[0].upper() == "H"
        and len(atom_name.strip()) == 4
        and "H" in SYMBOL_TO_NUM
    ):
        return "H"
    # Two-letter names starting in column 13 (e.g. "CA  " = calcium) are left
    # as-is so that unsupported elements are reported, not misassigned.
    return legacy


def parse_pdb(text: str) -> tuple[np.ndarray, np.ndarray]:
    """Parse PDB format text -> (coords, numbers).

    Reads ATOM/HETATM records using the fixed PDB columns: x/y/z from columns
    31-54 and the element from columns 77-78, falling back to the atom name
    when the element columns are absent. Records whose coordinate fields are
    not numeric are skipped. For multi-model files only the first model is
    read (parsing stops at the first ENDMDL that follows any atoms).

    Raises
    ------
    ValueError
        If an element is not in ``SYMBOL_TO_NUM`` or no atoms were found.
    """
    coords_list, nums_list = [], []
    for line in text.splitlines():
        # Later models would otherwise be stacked on top of the first one.
        if line.startswith("ENDMDL") and coords_list:
            break
        if not line.startswith(("ATOM", "HETATM")):
            continue
        try:
            x, y, z = float(line[30:38]), float(line[38:46]), float(line[46:54])
        except ValueError:
            continue
        elem = line[76:78].strip() if len(line) >= 78 else ""
        if elem:
            elem = elem.capitalize()
        else:
            elem = _pdb_element_from_name(line[12:16])
        if elem not in SYMBOL_TO_NUM:
            raise ValueError(f"Unknown element in PDB: {elem!r}")
        nums_list.append(SYMBOL_TO_NUM[elem])
        coords_list.append([x, y, z])
    if not coords_list:
        raise ValueError("No ATOM/HETATM records found.")
    return np.array(coords_list), np.array(nums_list)
def parse_input(text: str, fmt: str) -> tuple[np.ndarray, np.ndarray, str]:
    """Dispatch molecule text to the parser matching ``fmt``.

    ``fmt`` must be one of "SMILES", "XYZ" or "PDB". The result is
    ``(coords, numbers, warning)``; the warning string is currently always
    empty. The formal charge reported by the SMILES parser is discarded here.

    Raises
    ------
    ValueError
        If ``fmt`` is not a known format (parser errors propagate unchanged).
    """
    no_warning = ""
    if fmt == "SMILES":
        positions, atomic_nums, _ = parse_smiles(text)
        return positions, atomic_nums, no_warning
    if fmt == "XYZ":
        positions, atomic_nums = parse_xyz(text)
        return positions, atomic_nums, no_warning
    if fmt == "PDB":
        positions, atomic_nums = parse_pdb(text)
        return positions, atomic_nums, no_warning
    raise ValueError(f"Unknown format: {fmt}")
def handle_file_upload(file_obj) -> tuple[str, str]:
    """Process uploaded file. Returns (text_content, format_name).

    The returned text populates the molecule text box and the format name
    sets the format radio button.

    * ``None``            -> ``("", "SMILES")``
    * ``.xyz`` / ``.pdb`` -> the file text unchanged, with "XYZ" / "PDB"
    * ``.sdf`` / ``.mol`` -> the first record, converted to XYZ text via
      RDKit. If the record has no conformer, hydrogens are added and 3D
      coordinates are embedded.

    Parameters
    ----------
    file_obj
        Either an object with a ``name`` attribute holding the file path, or
        the path itself.

    Raises
    ------
    ValueError
        For an unsupported suffix, an unreadable SDF/MOL record, or when 3D
        coordinates cannot be generated.
    """
    if file_obj is None:
        return "", "SMILES"
    path = Path(file_obj.name if hasattr(file_obj, "name") else file_obj)
    suffix = path.suffix.lower()

    # The suffix is checked before the file is read so that an unsupported
    # (possibly binary) upload is reported as such, not as a decoding error.
    if suffix in (".xyz", ".pdb"):
        # Explicit encoding; undecodable bytes (e.g. in a comment line) are
        # replaced and must not abort the upload.
        text = path.read_text(encoding="utf-8", errors="replace")
        return text, suffix[1:].upper()

    if suffix in (".sdf", ".mol"):
        from rdkit import Chem
        from rdkit.Chem import AllChem

        suppl = Chem.SDMolSupplier(str(path), removeHs=False)
        mol = next(suppl, None)
        if mol is None:
            raise ValueError("Could not read SDF file.")
        if mol.GetNumConformers() == 0:
            mol = Chem.AddHs(mol)
            # EmbedMolecule returns -1 on failure; without this check the
            # GetConformer() call below would fail with an obscure error.
            if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
                raise ValueError(
                    "Failed to generate 3D coordinates for the uploaded structure."
                )
        # Convert to XYZ text
        conf = mol.GetConformer()
        n = mol.GetNumAtoms()
        xyz_lines = [str(n), f"Converted from {path.name}"]
        for i in range(n):
            pos = conf.GetAtomPosition(i)
            sym = mol.GetAtomWithIdx(i).GetSymbol()
            xyz_lines.append(f"{sym} {pos.x:.6f} {pos.y:.6f} {pos.z:.6f}")
        return "\n".join(xyz_lines), "XYZ"

    raise ValueError(f"Unsupported file type: {suffix}")
# ---------------------------------------------------------------------------
# 3D Viewer (iframe + 3Dmol.js)
# ---------------------------------------------------------------------------
def _charge_to_hex(q: float, qlim: float) -> str:
"""Map charge to color: red (negative) -> white (0) -> blue (positive)."""
t = np.clip((q + qlim) / (2 * qlim), 0, 1)
if t < 0.5:
s = t * 2
r, g, b = 1.0, s, s
else:
s = (t - 0.5) * 2
r, g, b = 1.0 - s, 1.0 - s, 1.0
return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
def build_viewer_html(
    coords: np.ndarray,
    numbers: np.ndarray,
    charges: np.ndarray | None = None,
    height: int = 420,
) -> str:
    """Build iframe HTML with 3Dmol.js viewer and CPK/charge toggle.

    The molecule is serialized to XYZ text and embedded in a standalone HTML
    page that loads 3Dmol.js from the cdnjs CDN. That page is HTML-escaped
    and placed in the ``srcdoc`` attribute of a sandboxed iframe (scripts
    allowed), which is the returned string.

    Parameters
    ----------
    coords : ndarray
        Per-atom positions; each row is unpacked as ``x, y, z``.
    numbers : ndarray
        Atomic numbers. Atoms not in ``ELEMENT_SYMBOLS`` are written as "X".
    charges : ndarray or None
        Per-atom partial charges. When given and the molecule has at most
        100 atoms, a button is added that switches between element (Jmol)
        colors and charge colors.
    height : int
        Viewer height in pixels; the iframe itself is 30 px taller.
    """
    n = len(numbers)
    # Build XYZ string
    xyz_lines = [str(n), "AIMNet2"]
    for i in range(n):
        sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X")
        x, y, z = coords[i]
        xyz_lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
    xyz_string = "\n".join(xyz_lines)
    # Build per-atom charge color JS (only if charges and <= 100 atoms)
    charge_js = ""
    has_toggle = charges is not None and n <= 100
    if has_toggle:
        # Symmetric color range; the 0.3 floor keeps near-neutral molecules
        # from being stretched to fully saturated colors.
        qlim = max(float(np.max(np.abs(charges))), 0.3)
        charge_styles = []
        for i in range(n):
            c = _charge_to_hex(float(charges[i]), qlim)
            # One JS statement per atom; doubled braces are literal braces
            # in the f-string.
            charge_styles.append(
                f'viewer.getModel().setAtomStyle({{index:{i}}},'
                f'{{stick:{{radius:0.15}},sphere:{{scale:0.25,color:"{c}"}}}});'
            )
        charge_js = "\n".join(charge_styles)
    toggle_btn = ""
    toggle_fn = ""
    if has_toggle:
        toggle_btn = (
            '<button id="toggle-btn" onclick="toggleColors()" '
            'style="position:absolute;top:8px;right:8px;z-index:10;'
            'padding:4px 10px;font-size:12px;cursor:pointer;'
            'border:1px solid #ccc;border-radius:4px;background:#f8f8f8;">'
            'Color by charge</button>'
        )
        # JS that flips between the two color modes and relabels the button.
        toggle_fn = f"""
var cpkMode = true;
function setCPK() {{
viewer.setStyle({{}}, {{stick:{{radius:0.15}}, sphere:{{scale:0.25, colorscheme:"Jmol"}}}});
viewer.render();
}}
function setCharges() {{
{charge_js}
viewer.render();
}}
function toggleColors() {{
cpkMode = !cpkMode;
if (cpkMode) {{
setCPK();
document.getElementById("toggle-btn").textContent = "Color by charge";
}} else {{
setCharges();
document.getElementById("toggle-btn").textContent = "CPK colors";
}}
}}
"""
    # Full page for the iframe. The XYZ text is injected as a JSON string
    # literal; if viewer setup throws, the fallback message is shown instead.
    inner_html = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8">
<style>
body {{ margin:0; overflow:hidden; font-family:sans-serif; }}
#viewer {{ width:100%; height:{height}px; position:relative; }}
#fallback {{ display:none; padding:20px; color:#888; text-align:center; }}
</style>
</head><body>
<div id="viewer"></div>
{toggle_btn}
<div id="fallback">3D viewer unavailable. Results are shown below.</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js"></script>
<script>
try {{
var xyz = {json.dumps(xyz_string)};
var viewer = $3Dmol.createViewer("viewer", {{backgroundColor:"white"}});
viewer.addModel(xyz, "xyz");
viewer.setStyle({{}}, {{stick:{{radius:0.15}}, sphere:{{scale:0.25, colorscheme:"Jmol"}}}});
viewer.zoomTo();
viewer.render();
{toggle_fn}
}} catch(e) {{
document.getElementById("viewer").style.display = "none";
document.getElementById("fallback").style.display = "block";
}}
</script>
</body></html>"""
    # The page is carried inside an attribute value, so it must be escaped
    # (including quotes) before being placed in srcdoc.
    escaped = html.escape(inner_html, quote=True)
    return (
        f'<iframe srcdoc="{escaped}" width="100%" height="{height + 30}" '
        f'frameborder="0" sandbox="allow-scripts" '
        f'style="border:1px solid #eee;border-radius:8px;"></iframe>'
    )
# ---------------------------------------------------------------------------
# Frequency computation
# ---------------------------------------------------------------------------
def is_linear(coords: np.ndarray, numbers: np.ndarray, tol: float = 1e-3) -> bool:
    """Decide whether a molecule is linear from its inertia tensor.

    The mass-weighted inertia tensor is built about the center of mass
    (atoms missing from ``ATOMIC_MASSES`` get mass 1.0). The molecule is
    treated as linear when the smallest principal moment is less than
    ``tol`` times the largest.
    """
    weights = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    centered = coords - np.average(coords, weights=weights, axis=0)

    inertia = np.zeros((3, 3))
    for mass, vec in zip(weights, centered):
        inertia += mass * (np.dot(vec, vec) * np.eye(3) - np.outer(vec, vec))

    moments = np.linalg.eigvalsh(inertia)  # ascending order
    smallest, largest = moments[0], moments[-1]
    # The floor on the denominator avoids dividing by zero for a single atom.
    return smallest / max(largest, 1e-30) < tol
def compute_frequencies(
    hessian: np.ndarray,
    numbers: np.ndarray,
    coords: np.ndarray,
) -> tuple[np.ndarray, int]:
    """Turn a Cartesian Hessian into vibrational wavenumbers.

    The Hessian is mass-weighted, symmetrized and diagonalized. Eigenvalues
    are converted to cm^-1 with the sign of the eigenvalue kept, so imaginary
    modes appear as negative numbers. Translations and rotations are not
    projected out; instead the 5 (linear) or 6 (non-linear) modes of smallest
    magnitude are discarded.

    Parameters
    ----------
    hessian : ndarray, shape (N,3,N,3) or (3N,3N)
        Hessian in eV/A^2.
    numbers : ndarray, shape (N,)
        Atomic numbers.
    coords : ndarray, shape (N,3)
        Atomic positions (for linearity check).

    Returns
    -------
    freqs_cm : ndarray
        Remaining frequencies in cm^-1, sorted ascending. Negative = imaginary.
    n_imag : int
        Number of returned frequencies below -10 cm^-1.
    """
    n_atoms = len(numbers)
    hess_flat = hessian.reshape(3 * n_atoms, 3 * n_atoms)

    # Mass-weight: each element is divided by sqrt(m_i * m_j).
    atom_masses = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    dof_masses = np.repeat(atom_masses, 3)
    weighted = hess_flat / np.sqrt(np.outer(dof_masses, dof_masses))
    weighted = 0.5 * (weighted + weighted.T)  # remove numerical asymmetry

    lam = np.linalg.eigvalsh(weighted)

    # Signed square root, then angular frequency -> wavenumber.
    wavenumbers = (
        np.sign(lam)
        * np.sqrt(np.abs(lam) * _FREQ_CONV)
        / (2 * np.pi * _C_CM)
    )

    # Count-based removal of the translation/rotation modes.
    n_drop = 5 if is_linear(coords, numbers) else 6
    by_magnitude = np.argsort(np.abs(wavenumbers))
    vibrational = np.sort(wavenumbers[by_magnitude[n_drop:]])

    imaginary_count = int(np.sum(vibrational < -10))
    return vibrational, imaginary_count
# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------
def make_frequency_plot(freqs: np.ndarray) -> go.Figure:
    """Return a Plotly stick spectrum of the positive frequencies.

    Every positive frequency is drawn as a unit-height bar; zero and negative
    (imaginary) values are left out. With no positive values the figure has
    the layout only and no trace.
    """
    positive = freqs[freqs > 0]
    figure = go.Figure()
    if len(positive) > 0:
        sticks = go.Bar(
            x=positive,
            y=np.ones_like(positive),
            width=3,
            marker_color="steelblue",
            hovertemplate="%{x:.1f} cm\u207b\u00b9<extra></extra>",
        )
        figure.add_trace(sticks)
    figure.update_layout(
        xaxis_title="Frequency (cm\u207b\u00b9)",
        yaxis_visible=False,
        height=200,
        margin=dict(l=40, r=20, t=30, b=40),
        title="Vibrational Spectrum",
        showlegend=False,
    )
    return figure
def make_convergence_plot(trajectory: list[dict]) -> go.Figure:
    """Return a two-axis Plotly figure of the optimization history.

    ``trajectory`` is a list of dicts with keys "step", "energy" and "fmax".
    Energy is drawn against the left y-axis and the maximum force against
    the right y-axis, both versus the step number.
    """
    step_axis = [rec["step"] for rec in trajectory]
    energy_series = [rec["energy"] for rec in trajectory]
    fmax_series = [rec["fmax"] for rec in trajectory]

    # (values, label, line color, drawn on the secondary axis?)
    curves = (
        (energy_series, "Energy (eV)", "steelblue", False),
        (fmax_series, "Max |F| (eV/\u00c5)", "firebrick", True),
    )

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    for values, label, color, on_secondary in curves:
        fig.add_trace(
            go.Scatter(
                x=step_axis,
                y=values,
                name=label,
                mode="lines+markers",
                marker=dict(size=4),
                line=dict(color=color),
            ),
            secondary_y=on_secondary,
        )

    fig.update_xaxes(title_text="Step")
    for _, label, _, on_secondary in curves:
        fig.update_yaxes(title_text=label, secondary_y=on_secondary)
    fig.update_layout(
        height=280,
        margin=dict(l=60, r=60, t=30, b=40),
        legend=dict(x=0.5, y=1.15, xanchor="center", orientation="h"),
    )
    return fig
# ---------------------------------------------------------------------------
# Geometry optimization
# ---------------------------------------------------------------------------
def run_optimization(
    atoms,
    max_steps: int,
    fmax_target: float,
    timeout: float = OPT_TIMEOUT,
) -> tuple[list[dict], bool, float]:
    """Run LBFGS optimization with timeout.

    Parameters
    ----------
    atoms
        ASE ``Atoms`` object with a calculator attached; its positions are
        updated in place.
    max_steps : int
        Maximum number of LBFGS steps.
    fmax_target : float
        Convergence threshold on the largest per-atom force norm.
    timeout : float
        Wall-clock budget in seconds, checked before each step.

    Returns
    -------
    trajectory : list[dict]
        One ``{"step", "energy", "fmax"}`` entry per completed step,
        describing the geometry reached by that step.
    converged : bool
        True if ``fmax`` dropped below ``fmax_target``.
    wall_time : float
        Elapsed seconds.
    """
    from ase.optimize import LBFGS

    opt = LBFGS(atoms, logfile=None)
    trajectory = []
    converged = False
    t0 = time.monotonic()  # monotonic: unaffected by system clock changes
    for step in range(1, max_steps + 1):
        if time.monotonic() - t0 > timeout:
            break
        opt.step()
        # opt.step() evaluates forces at the current geometry and then moves
        # the atoms, so the calculator's cached results describe the geometry
        # *before* the step. Request forces/energy through the Atoms object
        # to get values for the geometry that is actually returned. The next
        # opt.step() reuses this evaluation, so nothing is computed twice.
        forces = atoms.get_forces()
        energy = float(atoms.get_potential_energy())
        fmax = float(np.max(np.linalg.norm(forces, axis=1)))
        trajectory.append({"step": step, "energy": energy, "fmax": fmax})
        if fmax < fmax_target:
            converged = True
            break
    return trajectory, converged, time.monotonic() - t0
# ---------------------------------------------------------------------------
# Reproduction script generator
# ---------------------------------------------------------------------------
def _fmt_array(arr: np.ndarray, name: str) -> str:
"""Format numpy array as valid Python code."""
if arr.ndim == 1:
return f"{name} = {arr.tolist()!r}"
# 2D
rows = []
for row in arr:
rows.append(" [" + ", ".join(f"{v:.6f}" for v in row) + "],")
return f"{name} = np.array([\n" + "\n".join(rows) + "\n])"
def generate_script(
    coords: np.ndarray,
    numbers: np.ndarray,
    charge: int,
    task: str = "single_point",
    max_steps: int = 30,
    fmax: float = 0.05,
    compute_hessian: bool = False,
) -> str:
    """Return the source of a standalone script reproducing the calculation.

    The script always starts with imports, the embedded geometry and the
    calculator setup. ``task == "optimize"`` appends an LBFGS run using
    ``fmax`` and ``max_steps``; any other value appends a single-point
    evaluation. ``compute_hessian`` appends a Hessian section.
    """
    preamble = [
        "# AIMNet2 calculation",
        "# Generated by https://huggingface.co/spaces/isayevlab/aimnet2-demo",
        "from aimnet.calculators import AIMNet2Calculator",
        "from aimnet.calculators.aimnet2ase import AIMNet2ASE",
        "from ase import Atoms",
        "import numpy as np",
        "",
        _fmt_array(coords, "coords"),
        f"numbers = {numbers.tolist()!r}",
        f"charge = {charge}",
        "",
        'calc = AIMNet2ASE(AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3"), charge=charge)',
        "atoms = Atoms(numbers=numbers, positions=coords)",
        "atoms.calc = calc",
        "",
    ]

    if task == "optimize":
        task_section = [
            "from ase.optimize import LBFGS",
            "opt = LBFGS(atoms, logfile='-')",
            f"opt.run(fmax={fmax}, steps={max_steps})",
            "",
            "energy = atoms.get_potential_energy()",
            'print(f"Optimized energy: {energy:.6f} eV")',
            'print(f"Max force: {max(np.linalg.norm(atoms.get_forces(), axis=1)):.6f} eV/A")',
        ]
    else:
        task_section = [
            "energy = atoms.get_potential_energy()",
            "forces = atoms.get_forces()",
            'charges = atoms.calc.results["charges"]',
            'print(f"Energy: {energy:.6f} eV")',
        ]

    hessian_section = []
    if compute_hessian:
        hessian_section = [
            "",
            "# Hessian & frequencies",
            "base_calc = calc.base_calc",
            'hess_result = base_calc({"coord": atoms.get_positions(), '
            '"numbers": atoms.numbers, "charge": float(charge)}, hessian=True)',
            'hessian = hess_result["hessian"].detach().cpu().numpy()',
            "# Diagonalize mass-weighted Hessian for frequencies (see demo source for details)",
        ]

    return "\n".join([*preamble, *task_section, *hessian_section])
# ---------------------------------------------------------------------------
# XYZ download helper
# ---------------------------------------------------------------------------
def write_xyz_file(coords: np.ndarray, numbers: np.ndarray,
charges: np.ndarray | None = None,
comment: str = "AIMNet2") -> str:
"""Write XYZ to a temp file and return the path."""
n = len(numbers)
lines = [str(n), comment]
for i in range(n):
sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X")
x, y, z = coords[i]
q_str = f" {charges[i]:+.4f}" if charges is not None else ""
lines.append(f"{sym:2s} {x:12.6f} {y:12.6f} {z:12.6f}{q_str}")
tmp = tempfile.NamedTemporaryFile(suffix=".xyz", delete=False, mode="w")
tmp.write("\n".join(lines))
tmp.close()
return tmp.name
# ---------------------------------------------------------------------------
# Tab 1: Single-point calculation
# ---------------------------------------------------------------------------
def predict(input_text, input_format, charge, compute_forces, compute_hessian):
    """Run single-point calculation. Returns (markdown, viewer_html, freq_plot, xyz_file, script).

    Parameters
    ----------
    input_text : str
        Molecule as SMILES, XYZ or PDB text.
    input_format : str
        "SMILES", "XYZ" or "PDB".
    charge
        Total molecular charge; must be convertible to ``int``.
    compute_forces : bool
        Also report max/RMS force.
    compute_hessian : bool
        Also compute the Hessian and vibrational frequencies.

    Returns
    -------
    tuple
        ``(markdown, viewer_html, freq_plot, xyz_file, script)``. On any
        error the markdown carries the message and the remaining entries are
        ``"", None, None, ""``; errors are reported this way, not raised.
    """
    empty = ("", "", None, None, "")

    # The UI number field yields None when it is cleared.
    try:
        charge = int(charge)
    except (TypeError, ValueError):
        return ("**Error:** Charge must be an integer.", *empty[1:])

    # Parse
    try:
        coords, numbers, _ = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])
    n = len(numbers)
    if n > MAX_ATOMS:
        return (f"**Error:** {n} atoms exceeds limit of {MAX_ATOMS}.", *empty[1:])
    if compute_hessian and n > MAX_ATOMS_HESSIAN:
        return (f"**Error:** Hessian limited to {MAX_ATOMS_HESSIAN} atoms ({n} given).", *empty[1:])
    # Validate elements
    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # SMILES charge validation. Only the formal charge is needed, so the
    # SMILES is re-read without repeating the 3D embedding and force-field
    # relaxation (which is slow and could fail here, outside any try block).
    smiles_warn = ""
    if input_format == "SMILES":
        from rdkit import Chem

        mol = Chem.MolFromSmiles(input_text.strip())
        fc = Chem.GetFormalCharge(mol) if mol is not None else charge
        if fc != charge:
            smiles_warn = (
                f"\n> **Warning:** SMILES formal charge ({fc:+d}) != "
                f"supplied charge ({charge:+d}). Using supplied charge.\n"
            )

    # Calculate
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc
        atoms.get_potential_energy()
        energy_ev = float(ase_calc.results["energy"])
        charges_arr = ase_calc.results.get("charges")
        if not np.isfinite(energy_ev):
            return ("**Error:** Model produced NaN/Inf. Molecule may be outside training domain.", *empty[1:])
        forces_arr = None
        if compute_forces:
            atoms.get_forces()
            forces_arr = ase_calc.results["forces"]
        freqs = None
        n_imag = 0
        if compute_hessian:
            # The Hessian is requested from the base calculator directly.
            data = {"coord": coords, "numbers": numbers, "charge": float(charge)}
            hess_result = get_base_calc()(data, hessian=True)
            hessian_arr = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian_arr, numbers, coords)
    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(coords, numbers, charges_arr)
    # Results markdown
    energy_kcal = energy_ev * EV_TO_KCAL
    energy_ha = energy_ev / HARTREE_TO_EV
    md = []
    md.append("## AIMNet2 Results\n")
    if smiles_warn:
        md.append(smiles_warn)
    md.append(f"**Atoms:** {n} | **Charge:** {charge:+d}\n")
    md.append("### Energy\n| Unit | Value |\n|------|------:|")
    md.append(f"| eV | {energy_ev:.6f} |")
    md.append(f"| kcal/mol | {energy_kcal:.4f} |")
    md.append(f"| Hartree | {energy_ha:.8f} |\n")
    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")
    max_f = None
    if forces_arr is not None:
        max_f = float(np.max(np.linalg.norm(forces_arr, axis=1)))
        rms_f = float(np.sqrt(np.mean(forces_arr**2)))
        md.append("### Forces (eV/A)\n| Metric | Value |\n|--------|------:|")
        # The pipes in "|F|" must be escaped, otherwise they are read as
        # column separators and the value falls outside the two-column table.
        md.append(f"| Max \\|F\\| | {max_f:.6f} |")
        md.append(f"| RMS | {rms_f:.6f} |")
    if input_format == "SMILES":
        md.append("\n> *Geometry from MMFF, not AIMNet2-optimized.*\n")
    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        # Warn unless forces were computed and show a (near-)stationary point.
        if max_f is None or max_f > 0.05:
            md.append("> *Frequencies at non-stationary point. Low modes may be unreliable.*\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)
    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) | [Paper](https://doi.org/10.1039/D4SC08572H)*")
    xyz_file = write_xyz_file(coords, numbers, charges_arr,
                              comment=f"Energy: {energy_ev:.6f} eV")
    script = generate_script(coords, numbers, charge, "single_point",
                             compute_hessian=compute_hessian)
    return "\n".join(md), viewer_html, freq_plot, xyz_file, script
# ---------------------------------------------------------------------------
# Tab 2: Geometry optimization
# ---------------------------------------------------------------------------
def optimize(input_text, input_format, charge, max_steps, fmax_target,
             compute_freqs):
    """Run geometry optimization. Returns (md, viewer_html, conv_plot, freq_plot, xyz_file, script).

    Parameters
    ----------
    input_text, input_format
        Molecule text and its format ("SMILES", "XYZ" or "PDB").
    charge
        Total molecular charge; must be convertible to ``int``.
    max_steps
        Maximum number of optimizer steps; must be convertible to ``int``.
    fmax_target
        Force convergence threshold in eV/A, between 0.01 and 1.0. When
        ``compute_freqs`` is set, values above 0.02 are tightened to 0.02.
    compute_freqs : bool
        Also compute vibrational frequencies at the final geometry.

    Returns
    -------
    tuple
        ``(md, viewer_html, conv_plot, freq_plot, xyz_file, script)``. On any
        error the markdown carries the message and the remaining entries are
        ``"", None, None, None, ""``; errors are reported this way, not raised.
    """
    empty = ("", "", None, None, None, "")

    # UI number fields yield None when cleared.
    try:
        charge = int(charge)
        max_steps = int(max_steps)
        fmax_target = float(fmax_target)
    except (TypeError, ValueError):
        return ("**Error:** Charge, max steps and fmax must be numeric.", *empty[1:])

    # Validate the user-supplied fmax first; tightening before validating
    # would let an out-of-range value slip through whenever frequencies are
    # requested. (NaN fails the comparison and is rejected as well.)
    if not 0.01 <= fmax_target <= 1.0:
        return ("**Error:** fmax must be between 0.01 and 1.0 eV/A.", *empty[1:])
    # Auto-tighten fmax when frequencies requested
    if compute_freqs and fmax_target > 0.02:
        fmax_target = 0.02

    # Parse
    try:
        coords, numbers, _ = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])
    n = len(numbers)
    if n > MAX_ATOMS_OPT:
        return (f"**Error:** Optimization limited to {MAX_ATOMS_OPT} atoms ({n} given).", *empty[1:])
    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # Optimize
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc
        # Initial energy/forces (must request forces explicitly)
        atoms.get_forces()  # triggers full calc including energy
        e0 = float(ase_calc.results["energy"])
        f0 = ase_calc.results["forces"]
        fmax0 = float(np.max(np.linalg.norm(f0, axis=1)))
        trajectory, converged, wall_time = run_optimization(
            atoms, max_steps, fmax_target
        )
        opt_coords = atoms.get_positions()
        # With no completed step the initial values are also the final ones.
        e_final = trajectory[-1]["energy"] if trajectory else e0
        fmax_final = trajectory[-1]["fmax"] if trajectory else fmax0
        charges_arr = ase_calc.results.get("charges")
        if not np.isfinite(e_final):
            return ("**Error:** Model produced NaN/Inf during optimization.", *empty[1:])
        # Frequencies at optimized geometry
        freqs = None
        n_imag = 0
        if compute_freqs:
            data = {
                "coord": opt_coords,
                "numbers": atoms.numbers,
                "charge": float(charge),
            }
            hess_result = get_base_calc()(data, hessian=True)
            hessian = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian, atoms.numbers, opt_coords)
    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(opt_coords, numbers, charges_arr)
    conv_plot = make_convergence_plot(trajectory) if trajectory else None
    md = []
    md.append("## Optimization Results\n")
    status = "Converged" if converged else "Not converged"
    if not converged:
        md.append(f"> **{status}** after {len(trajectory)} steps / {wall_time:.1f}s "
                  f"(final fmax: {fmax_final:.4f} eV/A)\n")
    else:
        md.append(f"**{status}** in {len(trajectory)} steps ({wall_time:.1f}s)\n")
    md.append("| Property | Initial | Final |")
    md.append("|----------|--------:|------:|")
    md.append(f"| Energy (eV) | {e0:.6f} | {e_final:.6f} |")
    md.append(f"| Energy (kcal/mol) | {e0*EV_TO_KCAL:.4f} | {e_final*EV_TO_KCAL:.4f} |")
    # The pipes in "|F|" must be escaped, otherwise they are read as column
    # separators and shift the values out of their columns.
    md.append(f"| Max \\|F\\| (eV/A) | {fmax0:.6f} | {fmax_final:.6f} |")
    md.append(f"| dE (eV) | | {e_final - e0:.6f} |")
    md.append("")
    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")
    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)
    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3)*")
    xyz_file = write_xyz_file(opt_coords, numbers, charges_arr,
                              comment=f"Optimized, E={e_final:.6f} eV, fmax={fmax_final:.6f}")
    # The script embeds the *input* geometry so that running it repeats the
    # optimization from the same starting point.
    script = generate_script(coords, numbers, charge, "optimize",
                             max_steps=max_steps, fmax=fmax_target,
                             compute_hessian=compute_freqs)
    return "\n".join(md), viewer_html, conv_plot, freq_plot, xyz_file, script
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Example rows for the "Single Point" tab. Columns follow the inputs list
# given to gr.Examples there:
# [molecule text, format, charge, compute forces, compute Hessian].
CALC_EXAMPLES = [
    ["CCO", "SMILES", 0, True, False],
    ["c1ccccc1", "SMILES", 0, True, False],
    ["CC(=O)O", "SMILES", 0, True, False],
    ["[NH4+]", "SMILES", 1, True, False],
    ["CC(=O)[O-]", "SMILES", -1, True, False],
    ["O=C(O)c1ccccc1", "SMILES", 0, True, False],
    ["O", "SMILES", 0, True, True],
]
# Example rows for the "Optimize" tab:
# [molecule text, format, charge, max steps, fmax, compute frequencies].
OPT_EXAMPLES = [
    ["CCO", "SMILES", 0, 30, 0.05, False],
    ["O", "SMILES", 0, 30, 0.05, True],
]
# Static HTML shown in the viewer panels before the first result, and
# whenever a handler returns an empty viewer string (e.g. on error).
VIEWER_PLACEHOLDER = """<div style="
height:420px;display:flex;flex-direction:column;
align-items:center;justify-content:center;
background:linear-gradient(135deg,#f0f4ff 0%,#e8eeff 100%);
border-radius:12px;border:1px solid #dbeafe;
font-family:'Inter',system-ui,sans-serif;
">
<svg width="48" height="48" viewBox="0 0 24 24" fill="none" stroke="#94a3b8" stroke-width="1.5">
<path d="M12 2L2 7l10 5 10-5-10-5z"/><path d="M2 17l10 5 10-5"/>
<path d="M2 12l10 5 10-5"/>
</svg>
<p style="color:#94a3b8;margin-top:12px;font-size:14px;">
Enter a molecule and click Calculate or Optimize
</p>
</div>"""
# Custom CSS passed to gr.Blocks: loads the Inter font, caps the page width,
# and styles the title plus the "header-subtitle" / "header-links" spans that
# the header Markdown emits.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
.gradio-container {
max-width: 1200px !important;
font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
}
h1 {
font-weight: 700 !important;
letter-spacing: -0.02em !important;
color: #1e3a5f !important;
}
.header-subtitle {
color: #64748b !important;
font-size: 15px !important;
line-height: 1.6 !important;
margin-bottom: 4px !important;
}
.header-links a {
color: #3b82f6 !important;
text-decoration: none !important;
font-size: 13px !important;
}
.header-links a:hover { text-decoration: underline !important; }
"""
# Theme: Gradio Base theme with a blue/slate palette and Inter / Fira Code
# fonts; the .set() call overrides individual colors, borders and label
# sizes on top of that base.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=gr.themes.GoogleFont("Inter"),
    font_mono=gr.themes.GoogleFont("Fira Code"),
    radius_size=gr.themes.sizes.radius_md,
).set(
    body_background_fill="#fafbfc",
    block_background_fill="white",
    block_border_width="1px",
    block_border_color="#e2e8f0",
    block_shadow="0 1px 3px rgba(0,0,0,0.04)",
    block_label_text_size="13px",
    block_label_text_weight="500",
    block_label_text_color="#475569",
    button_primary_background_fill="#2563eb",
    button_primary_background_fill_hover="#1d4ed8",
    button_primary_text_color="white",
    input_border_color="#e2e8f0",
    input_background_fill="white",
    checkbox_label_text_size="13px",
)
# UI definition. The molecule text, format and charge inputs are created once
# above the tabs and are passed to the handlers of both tabs.
with gr.Blocks(title="AIMNet2 Demo", theme=THEME, css=CUSTOM_CSS) as demo:
    # --- Header ---
    gr.Markdown("# AIMNet2")
    gr.Markdown(
        '<span class="header-subtitle">'
        "Fast neural network interatomic potential &mdash; "
        "energy, forces, partial charges, geometry optimization, vibrational frequencies. "
        "Atoms colored by predicted charge."
        "</span>",
    )
    gr.Markdown(
        '<span class="header-links">'
        '[Model Card](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) &nbsp;&bull;&nbsp; '
        '[Paper](https://doi.org/10.1039/D4SC08572H) &nbsp;&bull;&nbsp; '
        '[GitHub](https://github.com/isayevlab/aimnetcentral) &nbsp;&bull;&nbsp; '
        '`pip install "aimnet[hf]"`'
        "</span>",
    )
    # --- Molecule Input (compact row) ---
    with gr.Row(equal_height=True):
        with gr.Column(scale=2, min_width=280):
            input_text = gr.Textbox(
                lines=4, label="Molecule",
                placeholder="Enter SMILES (e.g. CCO), paste XYZ coordinates, or PDB block...",
            )
        with gr.Column(scale=1, min_width=180):
            input_format = gr.Radio(
                ["SMILES", "XYZ", "PDB"], value="SMILES", label="Format",
            )
            with gr.Row():
                charge_input = gr.Number(value=0, precision=0, label="Charge", scale=1)
                file_upload = gr.File(
                    label="Upload",
                    file_types=[".xyz", ".pdb", ".sdf", ".mol"],
                    scale=1,
                )

    # File upload handler
    def on_file_upload(file_obj):
        """Fill the text box and format radio from an uploaded file.

        Returns no-op updates when the upload is cleared or fails, so the
        current input is left untouched; a failure is shown as a warning.
        """
        if file_obj is None:
            return gr.update(), gr.update()
        try:
            text, fmt = handle_file_upload(file_obj)
            gr.Info(f"Loaded file ({fmt} format)")
            return gr.update(value=text), gr.update(value=fmt)
        except Exception as e:
            gr.Warning(f"File upload failed: {e}")
            return gr.update(), gr.update()

    file_upload.change(
        on_file_upload, inputs=[file_upload], outputs=[input_text, input_format]
    )
    # --- Tabs ---
    with gr.Tabs() as tabs:
        # ===== Tab 1: Calculate =====
        with gr.TabItem("Single Point", id=0):
            with gr.Row(equal_height=False):
                # Left: controls
                with gr.Column(scale=1, min_width=220):
                    calc_forces = gr.Checkbox(value=True, label="Forces")
                    calc_hessian = gr.Checkbox(
                        value=False, label="Hessian & Frequencies"
                    )
                    calc_btn = gr.Button(
                        "Calculate", variant="primary", size="lg",
                    )
                    gr.Markdown("##### Try an example")
                    gr.Examples(
                        examples=CALC_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                calc_forces, calc_hessian],
                        label="",
                    )
                # Right: results
                with gr.Column(scale=3, min_width=500):
                    calc_viewer = gr.HTML(value=VIEWER_PLACEHOLDER)
                    calc_results = gr.Markdown()
                    # Hidden until a calculation actually produces a spectrum.
                    calc_freq_plot = gr.Plot(visible=False)
                    with gr.Row():
                        with gr.Accordion("Download XYZ", open=False):
                            calc_xyz = gr.File(interactive=False)
                        with gr.Accordion("Python code", open=False):
                            calc_script = gr.Code(language="python")

            def calc_wrapper(text, fmt, charge, forces, hessian):
                """Adapt predict() output to the components of this tab.

                An empty viewer string is replaced by the placeholder, and
                the frequency plot is shown only when predict() returned one.
                """
                md, viewer, fplot, xyz, script = predict(
                    text, fmt, charge, forces, hessian
                )
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )

            calc_btn.click(
                calc_wrapper,
                inputs=[input_text, input_format, charge_input,
                        calc_forces, calc_hessian],
                outputs=[calc_results, calc_viewer, calc_freq_plot,
                         calc_xyz, calc_script],
            )
        # ===== Tab 2: Optimize =====
        with gr.TabItem("Optimize", id=1):
            with gr.Row(equal_height=False):
                # Left: controls
                with gr.Column(scale=1, min_width=220):
                    opt_steps = gr.Slider(
                        10, 50, value=30, step=1, label="Max steps",
                    )
                    opt_fmax = gr.Number(
                        value=0.05, label="fmax (eV/A)",
                        minimum=0.01, maximum=1.0,
                    )
                    opt_freqs = gr.Checkbox(
                        value=False, label="Frequencies at minimum",
                    )
                    opt_btn = gr.Button(
                        "Optimize", variant="primary", size="lg",
                    )
                    gr.Markdown("##### Try an example")
                    gr.Examples(
                        examples=OPT_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                opt_steps, opt_fmax, opt_freqs],
                        label="",
                    )
                # Right: results
                with gr.Column(scale=3, min_width=500):
                    opt_viewer = gr.HTML(value=VIEWER_PLACEHOLDER)
                    opt_conv_plot = gr.Plot(label="Convergence")
                    opt_results = gr.Markdown()
                    # Hidden until an optimization actually produces a spectrum.
                    opt_freq_plot = gr.Plot(visible=False)
                    with gr.Row():
                        with gr.Accordion("Download XYZ", open=False):
                            opt_xyz = gr.File(interactive=False)
                        with gr.Accordion("Python code", open=False):
                            opt_script = gr.Code(language="python")

            def opt_wrapper(text, fmt, charge, steps, fmax, freqs):
                """Adapt optimize() output to the components of this tab.

                An empty viewer string is replaced by the placeholder, and
                the frequency plot is shown only when optimize() returned one.
                """
                md, viewer, conv, fplot, xyz, script = optimize(
                    text, fmt, charge, steps, fmax, freqs
                )
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    conv,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )

            opt_btn.click(
                opt_wrapper,
                inputs=[input_text, input_format, charge_input,
                        opt_steps, opt_fmax, opt_freqs],
                outputs=[opt_results, opt_viewer, opt_conv_plot,
                         opt_freq_plot, opt_xyz, opt_script],
            )
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()