Spaces:
Build error
Build error
| """AIMNet2 Interactive Demo v2. | |
| 3D visualization, geometry optimization, vibrational analysis, charge coloring. | |
| https://huggingface.co/spaces/isayevlab/aimnet2-demo | |
| """ | |
| from __future__ import annotations | |
| import html | |
| import json | |
| import os | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
# Turn off torch.compile / TorchDynamo for this process. HF Spaces CPU images
# have no compiler toolchain, and dynamo hangs during the Hessian autograd
# backward passes. Kept ahead of the `import torch` below.
os.environ.update(TORCHDYNAMO_DISABLE="1")
| import gradio as gr | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import torch | |
| from plotly.subplots import make_subplots | |
torch.set_num_threads(2)  # cap torch intra-op CPU threads at 2
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Atom-count limits per task.
MAX_ATOMS = 200          # single-point calculation
MAX_ATOMS_OPT = 50       # geometry optimization
MAX_ATOMS_HESSIAN = 50   # Hessian / vibrational frequencies
# Time budgets, in seconds.
REQUEST_TIMEOUT = 90     # cumulative per request
OPT_TIMEOUT = 85         # optimizer budget; leaves margin for the Hessian

# Energy conversion factors.
HARTREE_TO_EV = 27.211386024367243
EV_TO_KCAL = 23.06054783  # eV -> kcal/mol

# Supported elements (atomic number -> symbol) and the reverse lookup.
ELEMENT_SYMBOLS = {
    1: "H", 5: "B", 6: "C", 7: "N", 8: "O", 9: "F",
    14: "Si", 15: "P", 16: "S", 17: "Cl",
    33: "As", 34: "Se", 35: "Br", 46: "Pd", 53: "I",
}
SYMBOL_TO_NUM = dict(zip(ELEMENT_SYMBOLS.values(), ELEMENT_SYMBOLS.keys()))

# Atomic masses in amu (IUPAC 2021), keyed by atomic number.
ATOMIC_MASSES = {
    1: 1.008, 5: 10.81, 6: 12.011, 7: 14.007, 8: 15.999, 9: 18.998,
    14: 28.085, 15: 30.974, 16: 32.06, 17: 35.45, 33: 74.922,
    34: 78.971, 35: 79.904, 46: 106.42, 53: 126.904,
}

# Hessian eigenvalue unit conversion: eV/(A^2 * amu) -> s^-2.
_EV_TO_J = 1.602176634e-19       # joules per eV
_AMU_TO_KG = 1.66053906660e-27   # kilograms per amu
_A_TO_M = 1e-10                  # metres per Angstrom
_C_CM = 2.99792458e10            # speed of light in cm/s
_FREQ_CONV = _EV_TO_J / (_AMU_TO_KG * _A_TO_M**2)
| # --------------------------------------------------------------------------- | |
| # Model loader (eager, singleton) | |
| # --------------------------------------------------------------------------- | |
BASE_CALC = None


def get_base_calc():
    """Return the process-wide AIMNet2Calculator, creating it on first use.

    The first call loads ``isayevlab/aimnet2-wb97m-d3`` on CPU and caches it in
    the module global ``BASE_CALC``. Later calls return that same object.
    Sharing is intended for read-only use. No lock guards the first
    initialisation; the module triggers it eagerly at import time.
    """
    global BASE_CALC
    if BASE_CALC is not None:
        return BASE_CALC
    from aimnet.calculators import AIMNet2Calculator

    BASE_CALC = AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3", device="cpu")
    return BASE_CALC
def make_ase_calc(charge: int = 0):
    """Build a new AIMNet2ASE calculator for one request.

    Every call creates a fresh ASE wrapper carrying this request's ``charge``,
    so per-request wrapper state is not shared between concurrent requests.
    The underlying model is the shared instance from ``get_base_calc()``.
    """
    from aimnet.calculators.aimnet2ase import AIMNet2ASE as AseWrapper

    shared_model = get_base_calc()
    wrapper = AseWrapper(shared_model, charge=charge)
    return wrapper
# Eager-load the model at import time (during Space startup) so the first
# request does not pay the download/initialisation cost. A failure here is not
# fatal: BASE_CALC stays None, get_base_calc() retries on the first request and
# the error is reported to the user there. The exception is printed so that the
# startup log shows why the preload failed instead of hiding it.
try:
    get_base_calc()
except Exception as _preload_err:  # deliberate best-effort, but not silent
    import sys

    print(
        f"AIMNet2 model preload failed; will retry on first request: {_preload_err!r}",
        file=sys.stderr,
    )
| # --------------------------------------------------------------------------- | |
| # Parsers | |
| # --------------------------------------------------------------------------- | |
def parse_smiles(smiles: str) -> tuple[np.ndarray, np.ndarray, int]:
    """Parse SMILES -> (coords, numbers, formal_charge).

    Hydrogens are added explicitly. A 3D conformer is embedded with ETKDGv3
    using a fixed random seed, so the same SMILES always produces the same
    geometry. The conformer is then relaxed with MMFF.

    Returns
    -------
    coords : ndarray, shape (N, 3)
        Cartesian coordinates in Angstrom.
    numbers : ndarray, shape (N,)
        Atomic numbers, including the added hydrogens.
    formal_charge : int
        Net formal charge encoded in the SMILES.

    Raises
    ------
    ValueError
        If the SMILES is invalid or empty, or if 3D embedding fails.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(smiles.strip())
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    # MolFromSmiles("") returns an empty (not None) molecule; reject it here
    # instead of sending a zero-atom system to the model.
    if mol.GetNumAtoms() == 0:
        raise ValueError("Empty SMILES: no atoms to compute.")
    formal_charge = Chem.GetFormalCharge(mol)
    mol = Chem.AddHs(mol)
    params = AllChem.ETKDGv3()
    params.randomSeed = 0xF00D  # reproducible geometry for identical input
    if AllChem.EmbedMolecule(mol, params) == -1:
        raise ValueError("Failed to generate 3D coordinates. Try a different molecule.")
    AllChem.MMFFOptimizeMolecule(mol)
    coords = np.asarray(mol.GetConformer().GetPositions(), dtype=float)
    numbers = np.array([a.GetAtomicNum() for a in mol.GetAtoms()])
    return coords, numbers, formal_charge
def parse_xyz(text: str) -> tuple[np.ndarray, np.ndarray]:
    """Parse XYZ format text -> (coords, numbers).

    Accepts standard XYZ (atom count line, comment line, atom lines) and also
    bare atom lines with no header. When the count header is present, only
    the first frame is read, so multi-frame trajectory files do not merge
    frames into overlapping atoms. The element field may be a symbol
    (case-insensitive) or an atomic number. Lines with fewer than four fields
    are skipped.

    Raises
    ------
    ValueError
        If an element is not supported or no atom lines are found.
    """
    lines = [ln.strip() for ln in text.strip().splitlines()]
    has_header = bool(lines) and lines[0].isdigit()
    n_declared = int(lines[0]) if has_header else None
    supported_numbers = set(SYMBOL_TO_NUM.values())
    coords_list, nums_list = [], []
    for line in lines[2 if has_header else 0:]:
        if n_declared is not None and len(nums_list) >= n_declared:
            break  # end of the first frame; later frames are ignored
        parts = line.split()
        if len(parts) < 4:
            continue
        token = parts[0]
        if token.isdigit():
            z = int(token)
            if z not in supported_numbers:
                raise ValueError(f"Unknown element: {token!r}")
        else:
            sym = token.capitalize()
            if sym not in SYMBOL_TO_NUM:
                raise ValueError(f"Unknown element: {sym!r}")
            z = SYMBOL_TO_NUM[sym]
        nums_list.append(z)
        coords_list.append([float(parts[1]), float(parts[2]), float(parts[3])])
    if not coords_list:
        raise ValueError("No atoms found in XYZ input.")
    return np.array(coords_list), np.array(nums_list)
def parse_pdb(text: str) -> tuple[np.ndarray, np.ndarray]:
    """Parse PDB format text -> (coords, numbers).

    Reads ATOM/HETATM records from the first model only, stopping at the
    first ENDMDL. Only the first alternate-location indicator seen is kept.
    Without these rules, NMR ensembles and alt-loc structures would produce
    duplicated, overlapping atoms.

    The element comes from columns 77-78. When that field is blank or absent,
    it is derived from the atom-name field (columns 13-16) using the PDB
    convention: one-letter elements start in column 14 and two-letter
    elements in column 13. Records whose coordinates cannot be parsed are
    skipped.

    Raises
    ------
    ValueError
        If an element is not supported or no ATOM/HETATM record is found.
    """
    coords_list, nums_list = [], []
    first_alt = None
    for line in text.splitlines():
        if line.startswith("ENDMDL"):
            break  # ignore all subsequent models
        if not line.startswith(("ATOM", "HETATM")):
            continue
        alt = line[16:17].strip()
        if alt:
            if first_alt is None:
                first_alt = alt
            elif alt != first_alt:
                continue  # alternate conformer of an atom already read
        try:
            x, y, z = float(line[30:38]), float(line[38:46]), float(line[46:54])
        except ValueError:
            continue
        # Slicing past the end is safe, so no length guard is needed: a line
        # trimmed right after a left-justified one-letter element (77 chars)
        # still yields its element.
        elem = line[76:78].strip()
        if not elem:
            name = line[12:16]
            if name[:1].isalpha():
                # Column 13 occupied -> two-letter element (e.g. "CL1 ").
                elem = name[:2].strip().rstrip("0123456789")
            else:
                # Blank/digit in column 13 -> one-letter element in column 14
                # (e.g. " CA " -> C, "1HB " -> H).
                elem = name[1:2].strip()
        elem = elem.capitalize()
        if elem not in SYMBOL_TO_NUM:
            raise ValueError(f"Unknown element in PDB: {elem!r}")
        nums_list.append(SYMBOL_TO_NUM[elem])
        coords_list.append([x, y, z])
    if not coords_list:
        raise ValueError("No ATOM/HETATM records found.")
    return np.array(coords_list), np.array(nums_list)
def parse_input(text: str, fmt: str) -> tuple[np.ndarray, np.ndarray, str]:
    """Dispatch ``text`` to the parser for ``fmt``.

    ``fmt`` is one of ``"SMILES"``, ``"XYZ"`` or ``"PDB"``. Returns
    ``(coords, numbers, warning)``; ``warning`` is currently always the empty
    string. The SMILES formal charge is not returned.

    Raises
    ------
    ValueError
        For an unknown ``fmt``, or whatever the chosen parser raises.
    """
    no_warning = ""
    if fmt == "SMILES":
        xyz, atomic_nums, _formal_charge = parse_smiles(text)
        return xyz, atomic_nums, no_warning
    if fmt == "XYZ":
        return (*parse_xyz(text), no_warning)
    if fmt == "PDB":
        return (*parse_pdb(text), no_warning)
    raise ValueError(f"Unknown format: {fmt}")
def handle_file_upload(file_obj) -> tuple[str, str]:
    """Process uploaded file. Returns (text_content, format_name).

    The result populates the text input and sets the format radio.
    ``.xyz`` and ``.pdb`` files are returned verbatim. ``.sdf`` and ``.mol``
    files are converted to XYZ text with RDKit: if the file has no usable 3D
    coordinates (none, or a flat 2D depiction), hydrogens are added and a 3D
    conformer is embedded; otherwise the file's coordinates are kept and any
    implicit hydrogens are made explicit. ``None`` (upload cleared) yields
    ``("", "SMILES")``.

    Raises
    ------
    ValueError
        For an unsupported extension, an unreadable SDF, or failed embedding.
    """
    if file_obj is None:
        return "", "SMILES"
    # Gradio may pass a tempfile-like object (with .name) or a plain path.
    path = Path(file_obj.name if hasattr(file_obj, "name") else file_obj)
    suffix = path.suffix.lower()

    # Dispatch on the extension *before* reading. Binary uploads with an
    # unsupported extension then get a clear "Unsupported file type" error
    # instead of a decoding error, and SDF files are not read as text.
    if suffix == ".xyz":
        return path.read_text(encoding="utf-8", errors="replace"), "XYZ"
    if suffix == ".pdb":
        return path.read_text(encoding="utf-8", errors="replace"), "PDB"
    if suffix not in (".sdf", ".mol"):
        raise ValueError(f"Unsupported file type: {suffix}")

    from rdkit import Chem
    from rdkit.Chem import AllChem

    suppl = Chem.SDMolSupplier(str(path), removeHs=False)
    mol = next(suppl, None)
    if mol is None:
        raise ValueError("Could not read SDF file.")
    if mol.GetNumConformers() == 0 or not mol.GetConformer().Is3D():
        # No usable 3D coordinates: embed a fresh conformer with explicit Hs.
        mol = Chem.AddHs(mol)
        if AllChem.EmbedMolecule(mol, AllChem.ETKDGv3()) == -1:
            raise ValueError(f"Could not generate 3D coordinates for {path.name}.")
    else:
        # Keep the file's 3D coordinates but add implicit hydrogens with
        # geometric positions. Otherwise they would be missing from the XYZ
        # and the molecule would be computed as a radical.
        mol = Chem.AddHs(mol, addCoords=True)

    # Convert to XYZ text
    conf = mol.GetConformer()
    n = mol.GetNumAtoms()
    xyz_lines = [str(n), f"Converted from {path.name}"]
    for i in range(n):
        pos = conf.GetAtomPosition(i)
        sym = mol.GetAtomWithIdx(i).GetSymbol()
        xyz_lines.append(f"{sym} {pos.x:.6f} {pos.y:.6f} {pos.z:.6f}")
    return "\n".join(xyz_lines), "XYZ"
| # --------------------------------------------------------------------------- | |
| # 3D Viewer (iframe + 3Dmol.js) | |
| # --------------------------------------------------------------------------- | |
| def _charge_to_hex(q: float, qlim: float) -> str: | |
| """Map charge to color: red (negative) -> white (0) -> blue (positive).""" | |
| t = np.clip((q + qlim) / (2 * qlim), 0, 1) | |
| if t < 0.5: | |
| s = t * 2 | |
| r, g, b = 1.0, s, s | |
| else: | |
| s = (t - 0.5) * 2 | |
| r, g, b = 1.0 - s, 1.0 - s, 1.0 | |
| return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}" | |
| def build_viewer_html( | |
| coords: np.ndarray, | |
| numbers: np.ndarray, | |
| charges: np.ndarray | None = None, | |
| height: int = 420, | |
| ) -> str: | |
| """Build iframe HTML with 3Dmol.js viewer and CPK/charge toggle.""" | |
| n = len(numbers) | |
| # Build XYZ string | |
| xyz_lines = [str(n), "AIMNet2"] | |
| for i in range(n): | |
| sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X") | |
| x, y, z = coords[i] | |
| xyz_lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}") | |
| xyz_string = "\n".join(xyz_lines) | |
| # Build per-atom charge color JS (only if charges and <= 100 atoms) | |
| charge_js = "" | |
| has_toggle = charges is not None and n <= 100 | |
| if has_toggle: | |
| qlim = max(float(np.max(np.abs(charges))), 0.3) | |
| charge_styles = [] | |
| for i in range(n): | |
| c = _charge_to_hex(float(charges[i]), qlim) | |
| charge_styles.append( | |
| f'viewer.getModel().setAtomStyle({{index:{i}}},' | |
| f'{{stick:{{radius:0.15}},sphere:{{scale:0.25,color:"{c}"}}}});' | |
| ) | |
| charge_js = "\n".join(charge_styles) | |
| toggle_btn = "" | |
| toggle_fn = "" | |
| if has_toggle: | |
| toggle_btn = ( | |
| '<button id="toggle-btn" onclick="toggleColors()" ' | |
| 'style="position:absolute;top:8px;right:8px;z-index:10;' | |
| 'padding:4px 10px;font-size:12px;cursor:pointer;' | |
| 'border:1px solid #ccc;border-radius:4px;background:#f8f8f8;">' | |
| 'Color by charge</button>' | |
| ) | |
| toggle_fn = f""" | |
| var cpkMode = true; | |
| function setCPK() {{ | |
| viewer.setStyle({{}}, {{stick:{{radius:0.15}}, sphere:{{scale:0.25, colorscheme:"Jmol"}}}}); | |
| viewer.render(); | |
| }} | |
| function setCharges() {{ | |
| {charge_js} | |
| viewer.render(); | |
| }} | |
| function toggleColors() {{ | |
| cpkMode = !cpkMode; | |
| if (cpkMode) {{ | |
| setCPK(); | |
| document.getElementById("toggle-btn").textContent = "Color by charge"; | |
| }} else {{ | |
| setCharges(); | |
| document.getElementById("toggle-btn").textContent = "CPK colors"; | |
| }} | |
| }} | |
| """ | |
| inner_html = f"""<!DOCTYPE html> | |
| <html><head> | |
| <meta charset="utf-8"> | |
| <style> | |
| body {{ margin:0; overflow:hidden; font-family:sans-serif; }} | |
| #viewer {{ width:100%; height:{height}px; position:relative; }} | |
| #fallback {{ display:none; padding:20px; color:#888; text-align:center; }} | |
| </style> | |
| </head><body> | |
| <div id="viewer"></div> | |
| {toggle_btn} | |
| <div id="fallback">3D viewer unavailable. Results are shown below.</div> | |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.2/3Dmol-min.js"></script> | |
| <script> | |
| try {{ | |
| var xyz = {json.dumps(xyz_string)}; | |
| var viewer = $3Dmol.createViewer("viewer", {{backgroundColor:"white"}}); | |
| viewer.addModel(xyz, "xyz"); | |
| viewer.setStyle({{}}, {{stick:{{radius:0.15}}, sphere:{{scale:0.25, colorscheme:"Jmol"}}}}); | |
| viewer.zoomTo(); | |
| viewer.render(); | |
| {toggle_fn} | |
| }} catch(e) {{ | |
| document.getElementById("viewer").style.display = "none"; | |
| document.getElementById("fallback").style.display = "block"; | |
| }} | |
| </script> | |
| </body></html>""" | |
| escaped = html.escape(inner_html, quote=True) | |
| return ( | |
| f'<iframe srcdoc="{escaped}" width="100%" height="{height + 30}" ' | |
| f'frameborder="0" sandbox="allow-scripts" ' | |
| f'style="border:1px solid #eee;border-radius:8px;"></iframe>' | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Frequency computation | |
| # --------------------------------------------------------------------------- | |
def is_linear(coords: np.ndarray, numbers: np.ndarray, tol: float = 1e-3) -> bool:
    """Check if molecule is linear via moment of inertia tensor.

    A linear molecule has one principal moment of ~0, so the ratio of the
    smallest to the largest moment falls below ``tol``. Elements missing
    from ATOMIC_MASSES get mass 1.0. A single atom counts as linear because
    all of its moments are zero.
    """
    masses = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    com = np.average(coords, weights=masses, axis=0)
    r = coords - com
    inertia = np.zeros((3, 3))
    for m, ri in zip(masses, r):
        inertia += m * (np.dot(ri, ri) * np.eye(3) - np.outer(ri, ri))
    eigvals = np.linalg.eigvalsh(inertia)
    return bool(eigvals[0] / max(eigvals[-1], 1e-30) < tol)


def _rigid_body_basis(coords: np.ndarray, masses: np.ndarray) -> np.ndarray:
    """Mass-weighted rigid translations (3) and rotations (3) as columns, shape (3N, 6)."""
    n = len(masses)
    sqrt_m = np.sqrt(masses)[:, None]
    r = coords - np.average(coords, axis=0, weights=masses)
    columns = []
    for axis in np.eye(3):
        columns.append((np.tile(axis, (n, 1)) * sqrt_m).ravel())  # translation
    for axis in np.eye(3):
        columns.append((np.cross(axis, r) * sqrt_m).ravel())  # infinitesimal rotation
    return np.stack(columns, axis=1)


def compute_frequencies(
    hessian: np.ndarray,
    numbers: np.ndarray,
    coords: np.ndarray,
) -> tuple[np.ndarray, int]:
    """Compute vibrational frequencies from Hessian.

    The Hessian is mass-weighted and symmetrised. Rigid-body translations and
    rotations are then projected out (P H P, with P the projector onto the
    complement of their span) before diagonalisation. At non-stationary
    geometries, this keeps rigid-body motion from mixing into, and being
    mistaken for, genuine low-frequency vibrations. The ``n_tr`` (5 if
    linear, else 6) modes of smallest magnitude are then dropped; after
    projection these are the exactly-zero rigid-body modes.

    Parameters
    ----------
    hessian : ndarray, shape (N,3,N,3) or (3N,3N)
        Hessian in eV/A^2.
    numbers : ndarray, shape (N,)
        Atomic numbers.
    coords : ndarray, shape (N,3)
        Atomic positions in Angstrom (for linearity check and rotations).

    Returns
    -------
    freqs_cm : ndarray
        Vibrational frequencies in cm^-1, sorted ascending. Negative = imaginary.
    n_imag : int
        Number of imaginary frequencies below -10 cm^-1.
    """
    n = len(numbers)
    coords = np.asarray(coords, dtype=float)
    H = np.asarray(hessian, dtype=float).reshape(3 * n, 3 * n)
    # Mass-weight
    masses = np.array([ATOMIC_MASSES.get(int(z), 1.0) for z in numbers])
    masses_3n = np.repeat(masses, 3)
    H_mw = H / np.sqrt(np.outer(masses_3n, masses_3n))
    H_mw = 0.5 * (H_mw + H_mw.T)  # symmetrize

    # Project out rigid-body motion. SVD orders singular vectors by weight,
    # so the first n_tr left vectors span the translations and the
    # non-degenerate rotations (a linear molecule has one null rotation).
    n_tr = 5 if is_linear(coords, numbers) else 6
    u, _sv, _vt = np.linalg.svd(_rigid_body_basis(coords, masses), full_matrices=False)
    q = u[:, :n_tr]
    proj = np.eye(3 * n) - q @ q.T
    H_proj = proj @ H_mw @ proj
    eigenvalues = np.linalg.eigvalsh(0.5 * (H_proj + H_proj.T))

    # Convert to cm^-1 (sign carries "imaginary")
    freqs = (
        np.sign(eigenvalues)
        * np.sqrt(np.abs(eigenvalues) * _FREQ_CONV)
        / (2 * np.pi * _C_CM)
    )
    # Drop the n_tr rigid-body modes (now ~0 after projection)
    sorted_idx = np.argsort(np.abs(freqs))
    freqs_vib = np.sort(freqs[sorted_idx[n_tr:]])
    n_imag = int(np.sum(freqs_vib < -10))
    return freqs_vib, n_imag
| # --------------------------------------------------------------------------- | |
| # Plotting | |
| # --------------------------------------------------------------------------- | |
def make_frequency_plot(freqs: np.ndarray) -> go.Figure:
    """Return a 200px-tall Plotly stick spectrum of the real frequencies.

    Every positive frequency becomes a unit-height bar; imaginary (negative)
    modes are not drawn. If there are no positive frequencies, the figure
    has the layout but no trace.
    """
    positive = freqs[freqs > 0]
    figure = go.Figure()
    if positive.size:
        sticks = go.Bar(
            x=positive,
            y=np.ones_like(positive),
            width=3,
            marker_color="steelblue",
            hovertemplate="%{x:.1f} cm\u207b\u00b9<extra></extra>",
        )
        figure.add_trace(sticks)
    figure.update_layout(
        title="Vibrational Spectrum",
        xaxis_title="Frequency (cm\u207b\u00b9)",
        yaxis_visible=False,
        showlegend=False,
        height=200,
        margin=dict(l=40, r=20, t=30, b=40),
    )
    return figure
def make_convergence_plot(trajectory: list[dict]) -> go.Figure:
    """Plot optimizer progress: energy (left axis) and max |F| (right axis) per step.

    ``trajectory`` entries must provide ``"step"``, ``"energy"`` (eV) and
    ``"fmax"`` (eV/A) keys, as produced by ``run_optimization``.
    """
    # (result key, trace/axis label, line colour, plotted on secondary axis?)
    series = (
        ("energy", "Energy (eV)", "steelblue", False),
        ("fmax", "Max |F| (eV/\u00c5)", "firebrick", True),
    )
    step_numbers = [point["step"] for point in trajectory]
    columns = {key: [point[key] for point in trajectory] for key, *_ in series}

    figure = make_subplots(specs=[[{"secondary_y": True}]])
    for key, label, colour, on_right in series:
        figure.add_trace(
            go.Scatter(
                x=step_numbers, y=columns[key], name=label, mode="lines+markers",
                marker=dict(size=4), line=dict(color=colour),
            ),
            secondary_y=on_right,
        )
        figure.update_yaxes(title_text=label, secondary_y=on_right)
    figure.update_xaxes(title_text="Step")
    figure.update_layout(
        height=280, margin=dict(l=60, r=60, t=30, b=40),
        legend=dict(x=0.5, y=1.15, xanchor="center", orientation="h"),
    )
    return figure
| # --------------------------------------------------------------------------- | |
| # Geometry optimization | |
| # --------------------------------------------------------------------------- | |
def run_optimization(
    atoms,
    max_steps: int,
    fmax_target: float,
    timeout: float = OPT_TIMEOUT,
) -> tuple[list[dict], bool, float]:
    """Run LBFGS on ``atoms`` in place, one step at a time, under a time budget.

    After every ``opt.step()`` the calculator is evaluated at the *new*
    positions, so each trajectory entry and the convergence test describe
    the geometry the optimizer just produced. Reading ``atoms.calc.results``
    directly after ``step()`` would instead return the energy and forces of
    the previous geometry. That evaluation costs no extra model call: the
    next ``opt.step()`` asks for forces at the same positions and ASE serves
    them from the calculator cache.

    If the starting geometry already satisfies ``fmax_target``, no step is
    taken.

    Returns
    -------
    trajectory : list[dict]
        One ``{"step", "energy", "fmax"}`` entry per step (energy in eV,
        fmax in eV/A), describing the geometry after that step.
    converged : bool
        Whether max |F| dropped below ``fmax_target``.
    wall_time : float
        Elapsed seconds.
    """
    from ase.optimize import LBFGS

    def _max_force(forces) -> float:
        return float(np.max(np.linalg.norm(forces, axis=1)))

    opt = LBFGS(atoms, logfile=None)
    trajectory: list[dict] = []
    t0 = time.monotonic()
    if _max_force(atoms.get_forces()) < fmax_target:
        return trajectory, True, time.monotonic() - t0

    converged = False
    for step in range(1, max_steps + 1):
        if time.monotonic() - t0 > timeout:
            break
        opt.step()
        # Evaluate at the updated positions (get_forces also fills "energy",
        # the same pattern the caller relies on for the initial geometry).
        fmax = _max_force(atoms.get_forces())
        energy = float(atoms.calc.results["energy"])
        trajectory.append({"step": step, "energy": energy, "fmax": fmax})
        if fmax < fmax_target:
            converged = True
            break
    return trajectory, converged, time.monotonic() - t0
| # --------------------------------------------------------------------------- | |
| # Reproduction script generator | |
| # --------------------------------------------------------------------------- | |
def _fmt_array(arr: np.ndarray, name: str) -> str:
    """Format numpy array as valid Python code (``name = ...``).

    1-D input becomes a plain list literal (``repr`` of ``tolist()``). 2-D
    input becomes an ``np.array([...])`` literal with one row per line and
    values fixed to 6 decimals.
    """
    arr = np.asarray(arr)
    if arr.ndim == 1:
        return f"{name} = {arr.tolist()!r}"
    # 2D
    rows = [" [" + ", ".join(f"{v:.6f}" for v in row) + "]," for row in arr]
    return f"{name} = np.array([\n" + "\n".join(rows) + "\n])"


def generate_script(
    coords: np.ndarray,
    numbers: np.ndarray,
    charge: int,
    task: str = "single_point",
    max_steps: int = 30,
    fmax: float = 0.05,
    compute_hessian: bool = False,
) -> str:
    """Generate a standalone Python script reproducing the calculation.

    ``task="optimize"`` emits an LBFGS run with ``fmax`` / ``max_steps``.
    Any other value emits a single-point evaluation. With
    ``compute_hessian``, the script also evaluates the Hessian and converts
    it to harmonic frequencies. That conversion is a simplified treatment:
    ASE masses, and the lowest modes are dropped by count. It is explained
    in comments inside the generated code.
    """
    lines = [
        "# AIMNet2 calculation",
        "# Generated by https://huggingface.co/spaces/isayevlab/aimnet2-demo",
        "from aimnet.calculators import AIMNet2Calculator",
        "from aimnet.calculators.aimnet2ase import AIMNet2ASE",
        "from ase import Atoms",
        "import numpy as np",
        "",
        _fmt_array(coords, "coords"),
        f"numbers = {numbers.tolist()!r}",
        f"charge = {charge}",
        "",
        'calc = AIMNet2ASE(AIMNet2Calculator("isayevlab/aimnet2-wb97m-d3"), charge=charge)',
        "atoms = Atoms(numbers=numbers, positions=coords)",
        "atoms.calc = calc",
        "",
    ]
    if task == "optimize":
        lines += [
            "from ase.optimize import LBFGS",
            "opt = LBFGS(atoms, logfile='-')",
            f"opt.run(fmax={fmax}, steps={max_steps})",
            "",
            "energy = atoms.get_potential_energy()",
            'print(f"Optimized energy: {energy:.6f} eV")',
            'print(f"Max force: {max(np.linalg.norm(atoms.get_forces(), axis=1)):.6f} eV/A")',
        ]
    else:
        lines += [
            "energy = atoms.get_potential_energy()",
            "forces = atoms.get_forces()",
            'charges = atoms.calc.results["charges"]',
            'print(f"Energy: {energy:.6f} eV")',
        ]
    if compute_hessian:
        # Diatomics are always linear (5 rigid-body modes); for larger
        # linear molecules the user is told to set n_tr = 5 by hand.
        n_tr = 5 if len(numbers) <= 2 else 6
        lines += [
            "",
            "# Hessian & frequencies",
            "base_calc = calc.base_calc",
            'hess_result = base_calc({"coord": atoms.get_positions(), '
            '"numbers": atoms.numbers, "charge": float(charge)}, hessian=True)',
            'hessian = hess_result["hessian"].detach().cpu().numpy()',
            "",
            "# Harmonic frequencies from the mass-weighted Hessian (cm^-1; negative = imaginary).",
            "# Simplified: ASE atomic masses; the n_tr smallest-|f| modes are dropped as",
            "# translations/rotations (use n_tr = 5 for linear molecules).",
            "n_atoms = len(atoms)",
            "m3 = np.repeat(atoms.get_masses(), 3)",
            "H_mw = hessian.reshape(3 * n_atoms, 3 * n_atoms) / np.sqrt(np.outer(m3, m3))",
            "eigvals = np.linalg.eigvalsh(0.5 * (H_mw + H_mw.T))",
            "to_s2 = 1.602176634e-19 / (1e-20 * 1.66053906660e-27)  # eV/(A^2 amu) -> s^-2",
            "freqs = np.sign(eigvals) * np.sqrt(np.abs(eigvals) * to_s2) / (2 * np.pi * 2.99792458e10)",
            f"n_tr = {n_tr}",
            "freqs = np.sort(freqs[np.argsort(np.abs(freqs))[n_tr:]])",
            'print("Frequencies (cm^-1):", np.round(freqs, 2))',
        ]
    return "\n".join(lines)
| # --------------------------------------------------------------------------- | |
| # XYZ download helper | |
| # --------------------------------------------------------------------------- | |
| def write_xyz_file(coords: np.ndarray, numbers: np.ndarray, | |
| charges: np.ndarray | None = None, | |
| comment: str = "AIMNet2") -> str: | |
| """Write XYZ to a temp file and return the path.""" | |
| n = len(numbers) | |
| lines = [str(n), comment] | |
| for i in range(n): | |
| sym = ELEMENT_SYMBOLS.get(int(numbers[i]), "X") | |
| x, y, z = coords[i] | |
| q_str = f" {charges[i]:+.4f}" if charges is not None else "" | |
| lines.append(f"{sym:2s} {x:12.6f} {y:12.6f} {z:12.6f}{q_str}") | |
| tmp = tempfile.NamedTemporaryFile(suffix=".xyz", delete=False, mode="w") | |
| tmp.write("\n".join(lines)) | |
| tmp.close() | |
| return tmp.name | |
| # --------------------------------------------------------------------------- | |
| # Tab 1: Single-point calculation | |
| # --------------------------------------------------------------------------- | |
def predict(input_text, input_format, charge, compute_forces, compute_hessian):
    """Run a single-point AIMNet2 calculation (Calculate tab).

    Parameters
    ----------
    input_text : str
        Molecule as SMILES, XYZ or PDB text.
    input_format : str
        ``"SMILES"``, ``"XYZ"`` or ``"PDB"``.
    charge : int-like
        Total molecular charge passed to the model.
    compute_forces, compute_hessian : bool
        Whether to report forces and vibrational frequencies.

    Returns
    -------
    tuple
        ``(markdown, viewer_html, freq_plot, xyz_file, script)``. On any
        error only ``markdown`` is filled (with the message); the other
        slots are ``""`` / ``None``.
    """
    charge = int(charge)
    empty = ("", "", None, None, "")

    # Parse. SMILES is parsed exactly once and its formal charge kept, so the
    # charge check below refers to the molecule actually computed. Parsing
    # again would re-embed the geometry, and could raise outside this try.
    smiles_fc = None
    try:
        if input_format == "SMILES":
            coords, numbers, smiles_fc = parse_smiles(input_text)
        else:
            coords, numbers, _warning = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])

    n = len(numbers)
    if n == 0:
        return ("**Error:** No atoms found in input.", *empty[1:])
    if n > MAX_ATOMS:
        return (f"**Error:** {n} atoms exceeds limit of {MAX_ATOMS}.", *empty[1:])
    if compute_hessian and n > MAX_ATOMS_HESSIAN:
        return (f"**Error:** Hessian limited to {MAX_ATOMS_HESSIAN} atoms ({n} given).", *empty[1:])

    # Validate elements (SMILES may contain elements the model does not support)
    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # SMILES charge validation
    smiles_warn = ""
    if smiles_fc is not None and smiles_fc != charge:
        smiles_warn = (
            f"\n> **Warning:** SMILES formal charge ({smiles_fc:+d}) != "
            f"supplied charge ({charge:+d}). Using supplied charge.\n"
        )

    # Calculate
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc
        atoms.get_potential_energy()
        energy_ev = float(ase_calc.results["energy"])
        charges_arr = ase_calc.results.get("charges")
        if not np.isfinite(energy_ev):
            return ("**Error:** Model produced NaN/Inf. Molecule may be outside training domain.", *empty[1:])
        forces_arr = None
        if compute_forces:
            atoms.get_forces()
            forces_arr = ase_calc.results["forces"]
        freqs = None
        n_imag = 0
        if compute_hessian:
            data = {"coord": coords, "numbers": numbers, "charge": float(charge)}
            hess_result = get_base_calc()(data, hessian=True)
            hessian_arr = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian_arr, numbers, coords)
    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(coords, numbers, charges_arr)

    # Results markdown
    energy_kcal = energy_ev * EV_TO_KCAL
    energy_ha = energy_ev / HARTREE_TO_EV
    md = ["## AIMNet2 Results\n"]
    if smiles_warn:
        md.append(smiles_warn)
    md.append(f"**Atoms:** {n} | **Charge:** {charge:+d}\n")
    md.append("### Energy\n| Unit | Value |\n|------|------:|")
    md.append(f"| eV | {energy_ev:.6f} |")
    md.append(f"| kcal/mol | {energy_kcal:.4f} |")
    md.append(f"| Hartree | {energy_ha:.8f} |\n")
    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")
    max_f = None
    if forces_arr is not None:
        max_f = float(np.max(np.linalg.norm(forces_arr, axis=1)))
        rms_f = float(np.sqrt(np.mean(forces_arr**2)))
        md.append("### Forces (eV/A)\n| Metric | Value |\n|--------|------:|")
        md.append(f"| Max |F| | {max_f:.6f} |")
        md.append(f"| RMS | {rms_f:.6f} |")
    if input_format == "SMILES":
        md.append("\n> *Geometry from MMFF, not AIMNet2-optimized.*\n")

    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        # Without computed forces the geometry cannot be confirmed as
        # stationary, so the caveat is shown in that case too.
        if max_f is None or max_f > 0.05:
            md.append("> *Frequencies at non-stationary point. Low modes may be unreliable.*\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)

    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) | [Paper](https://doi.org/10.1039/D4SC08572H)*")
    xyz_file = write_xyz_file(coords, numbers, charges_arr,
                              comment=f"Energy: {energy_ev:.6f} eV")
    script = generate_script(coords, numbers, charge, "single_point",
                             compute_hessian=compute_hessian)
    return "\n".join(md), viewer_html, freq_plot, xyz_file, script
| # --------------------------------------------------------------------------- | |
| # Tab 2: Geometry optimization | |
| # --------------------------------------------------------------------------- | |
def optimize(input_text, input_format, charge, max_steps, fmax_target,
             compute_freqs):
    """Run an AIMNet2 LBFGS geometry optimization (Optimize tab).

    ``fmax_target`` (eV/A) must lie in [0.01, 1.0]. When frequencies are
    requested, it is tightened to 0.02 and the report says so. The final
    energy, max force, charges and convergence flag are all evaluated at the
    returned geometry.

    Returns
    -------
    tuple
        ``(markdown, viewer_html, conv_plot, freq_plot, xyz_file, script)``.
        On any error only ``markdown`` is filled; the other slots are
        ``""`` / ``None``.
    """
    charge = int(charge)
    max_steps = int(max_steps)
    fmax_target = float(fmax_target)
    empty = ("", "", None, None, None, "")

    # Validate the user's value before any automatic adjustment, so an
    # out-of-range request is reported instead of silently replaced.
    if not 0.01 <= fmax_target <= 1.0:
        return ("**Error:** fmax must be between 0.01 and 1.0 eV/A.", *empty[1:])
    # Frequencies need a well-converged minimum: tighten fmax and say so.
    fmax_note = ""
    if compute_freqs and fmax_target > 0.02:
        fmax_note = (
            f"> *fmax tightened from {fmax_target:g} to 0.02 eV/A "
            "for the frequency calculation.*\n"
        )
        fmax_target = 0.02

    # Parse
    try:
        coords, numbers, _ = parse_input(input_text, input_format)
    except Exception as e:
        return (f"**Parse error:** {e}", *empty[1:])
    n = len(numbers)
    if n == 0:
        return ("**Error:** No atoms found in input.", *empty[1:])
    if n > MAX_ATOMS_OPT:
        return (f"**Error:** Optimization limited to {MAX_ATOMS_OPT} atoms ({n} given).", *empty[1:])
    unsupported = sorted({int(z) for z in numbers} - set(ELEMENT_SYMBOLS))
    if unsupported:
        return (f"**Error:** Unsupported elements: {unsupported}", *empty[1:])

    # Optimize
    try:
        ase_calc = make_ase_calc(charge)
        from ase import Atoms
        symbols = [ELEMENT_SYMBOLS[int(z)] for z in numbers]
        atoms = Atoms(symbols=symbols, positions=coords)
        atoms.calc = ase_calc
        # Initial energy/forces (must request forces explicitly)
        atoms.get_forces()  # triggers full calc including energy
        e0 = float(ase_calc.results["energy"])
        fmax0 = float(np.max(np.linalg.norm(ase_calc.results["forces"], axis=1)))

        trajectory, _opt_converged, wall_time = run_optimization(
            atoms, max_steps, fmax_target
        )

        # Re-evaluate at the geometry actually returned, so that the energy,
        # max force, charges and convergence flag all describe opt_coords.
        # If the optimizer already evaluated this geometry, ASE serves it from
        # the calculator cache; otherwise it costs one model call.
        opt_coords = atoms.get_positions()
        atoms.get_forces()
        e_final = float(ase_calc.results["energy"])
        fmax_final = float(np.max(np.linalg.norm(ase_calc.results["forces"], axis=1)))
        charges_arr = ase_calc.results.get("charges")
        converged = fmax_final < fmax_target
        if not (np.isfinite(e_final) and np.isfinite(fmax_final)):
            return ("**Error:** Model produced NaN/Inf during optimization.", *empty[1:])

        # Frequencies at optimized geometry
        freqs = None
        n_imag = 0
        if compute_freqs:
            data = {
                "coord": opt_coords,
                "numbers": atoms.numbers,
                "charge": float(charge),
            }
            hess_result = get_base_calc()(data, hessian=True)
            hessian = hess_result["hessian"].detach().cpu().numpy()
            freqs, n_imag = compute_frequencies(hessian, atoms.numbers, opt_coords)
    except Exception as e:
        import traceback
        return (f"**Calculation error:** {e}\n```\n{traceback.format_exc()}\n```", *empty[1:])

    # Build outputs
    viewer_html = build_viewer_html(opt_coords, numbers, charges_arr)
    conv_plot = make_convergence_plot(trajectory) if trajectory else None

    md = ["## Optimization Results\n"]
    if converged:
        md.append(f"**Converged** in {len(trajectory)} steps ({wall_time:.1f}s)\n")
    else:
        md.append(f"> **Not converged** after {len(trajectory)} steps / {wall_time:.1f}s "
                  f"(final fmax: {fmax_final:.4f} eV/A)\n")
    if fmax_note:
        md.append(fmax_note)
    md.append("| Property | Initial | Final |")
    md.append("|----------|--------:|------:|")
    md.append(f"| Energy (eV) | {e0:.6f} | {e_final:.6f} |")
    md.append(f"| Energy (kcal/mol) | {e0*EV_TO_KCAL:.4f} | {e_final*EV_TO_KCAL:.4f} |")
    md.append(f"| Max |F| (eV/A) | {fmax0:.6f} | {fmax_final:.6f} |")
    md.append(f"| dE (eV) | | {e_final - e0:.6f} |")
    md.append("")
    if charges_arr is not None:
        md.append("### Partial Charges (e)\n| # | Elem | Charge |\n|--:|:----:|-------:|")
        for i, (z, q) in enumerate(zip(numbers, charges_arr)):
            sym = ELEMENT_SYMBOLS.get(int(z), "?")
            md.append(f"| {i+1} | {sym} | {q:+.4f} |")
        md.append(f"\n*Sum: {float(np.sum(charges_arr)):+.4f} e*\n")

    freq_plot = None
    if freqs is not None:
        real_f = freqs[freqs > 0]
        imag_f = freqs[freqs < 0]
        md.append("### Vibrational Frequencies\n")
        if not converged:
            md.append("> *Frequencies at non-stationary point. Low modes may be unreliable.*\n")
        if n_imag > 0:
            md.append(f"> **{n_imag} imaginary frequency(ies)** -- not a true minimum.\n")
        if len(real_f) > 0:
            md.append("```")
            for j, f in enumerate(real_f):
                md.append(f" {j+1:3d}: {f:10.2f} cm-1")
            md.append("```")
        if len(imag_f) > 0:
            md.append("\nImaginary:\n```")
            for j, f in enumerate(imag_f):
                md.append(f" {j+1:3d}: {abs(f):10.2f}i cm-1")
            md.append("```")
        freq_plot = make_frequency_plot(freqs)

    md.append("\n---")
    md.append("*AIMNet2 wB97M-D3 | [Model](https://huggingface.co/isayevlab/aimnet2-wb97m-d3)*")
    xyz_file = write_xyz_file(opt_coords, numbers, charges_arr,
                              comment=f"Optimized, E={e_final:.6f} eV, fmax={fmax_final:.6f}")
    script = generate_script(coords, numbers, charge, "optimize",
                             max_steps=max_steps, fmax=fmax_target,
                             compute_hessian=compute_freqs)
    return "\n".join(md), viewer_html, conv_plot, freq_plot, xyz_file, script
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Example rows. Columns line up with the parameters of predict()
# (input_text, input_format, charge, compute_forces, compute_hessian) and of
# optimize() (input_text, input_format, charge, max_steps, fmax_target,
# compute_freqs); presumably wired to gr.Examples further down.
_CALC_EXAMPLE_SPECS = (
    # (SMILES, charge, compute_hessian); forces are always requested
    ("CCO", 0, False),
    ("c1ccccc1", 0, False),
    ("CC(=O)O", 0, False),
    ("[NH4+]", 1, False),
    ("CC(=O)[O-]", -1, False),
    ("O=C(O)c1ccccc1", 0, False),
    ("O", 0, True),
)
CALC_EXAMPLES = [
    [smiles, "SMILES", q, True, with_hessian]
    for smiles, q, with_hessian in _CALC_EXAMPLE_SPECS
]
OPT_EXAMPLES = [
    [smiles, "SMILES", 0, 30, 0.05, with_freqs]
    for smiles, with_freqs in (("CCO", False), ("O", True))
]
# Static HTML shown in the viewer slot before any calculation has run: a
# centred layers icon and a hint. Its 420px height matches the default
# height of build_viewer_html().
VIEWER_PLACEHOLDER = """<div style="
height:420px;display:flex;flex-direction:column;
align-items:center;justify-content:center;
background:linear-gradient(135deg,#f0f4ff 0%,#e8eeff 100%);
border-radius:12px;border:1px solid #dbeafe;
font-family:'Inter',system-ui,sans-serif;
">
<svg width="48" height="48" viewBox="0 0 24 24" fill="none" stroke="#94a3b8" stroke-width="1.5">
<path d="M12 2L2 7l10 5 10-5-10-5z"/><path d="M2 17l10 5 10-5"/>
<path d="M2 12l10 5 10-5"/>
</svg>
<p style="color:#94a3b8;margin-top:12px;font-size:14px;">
Enter a molecule and click Calculate or Optimize
</p>
</div>"""
# Custom CSS passed to gr.Blocks(css=...). It loads the Inter web font, caps
# the app width at 1200px, and styles the header title (h1) plus the
# .header-subtitle / .header-links spans used in the header markdown.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
.gradio-container {
max-width: 1200px !important;
font-family: 'Inter', system-ui, -apple-system, sans-serif !important;
}
h1 {
font-weight: 700 !important;
letter-spacing: -0.02em !important;
color: #1e3a5f !important;
}
.header-subtitle {
color: #64748b !important;
font-size: 15px !important;
line-height: 1.6 !important;
margin-bottom: 4px !important;
}
.header-links a {
color: #3b82f6 !important;
text-decoration: none !important;
font-size: 13px !important;
}
.header-links a:hover { text-decoration: underline !important; }
"""
# Theme
# Base theme with a blue primary palette, slate secondary/neutral palettes and
# medium corner radius. Each font list ends in generic fallbacks, the same idea
# as the font-family stack in CUSTOM_CSS. Without them, text would drop to the
# browser default if the Google-hosted fonts cannot be fetched.
THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[
        gr.themes.GoogleFont("Fira Code"),
        "ui-monospace",
        "Consolas",
        "monospace",
    ],
    radius_size=gr.themes.sizes.radius_md,
).set(
    # Only light-mode variables are overridden here (no *_dark keys are set).
    # NOTE(review): confirm dark mode still looks acceptable with the defaults.
    body_background_fill="#fafbfc",
    block_background_fill="white",
    block_border_width="1px",
    block_border_color="#e2e8f0",
    block_shadow="0 1px 3px rgba(0,0,0,0.04)",
    block_label_text_size="13px",
    block_label_text_weight="500",
    block_label_text_color="#475569",
    button_primary_background_fill="#2563eb",
    button_primary_background_fill_hover="#1d4ed8",
    button_primary_text_color="white",
    input_border_color="#e2e8f0",
    input_background_fill="white",
    checkbox_label_text_size="13px",
)
# Top-level Gradio app. Components are placed on screen in the order they are
# created inside these nested `with` contexts, so statement order here matters.
with gr.Blocks(title="AIMNet2 Demo", theme=THEME, css=CUSTOM_CSS) as demo:
    # --- Header ---
    # The <span> classes below are styled by CUSTOM_CSS
    # (.header-subtitle / .header-links).
    gr.Markdown("# AIMNet2")
    gr.Markdown(
        '<span class="header-subtitle">'
        "Fast neural network interatomic potential — "
        "energy, forces, partial charges, geometry optimization, vibrational frequencies. "
        "Atoms colored by predicted charge."
        "</span>",
    )
    gr.Markdown(
        '<span class="header-links">'
        '[Model Card](https://huggingface.co/isayevlab/aimnet2-wb97m-d3) • '
        '[Paper](https://doi.org/10.1039/D4SC08572H) • '
        '[GitHub](https://github.com/isayevlab/aimnetcentral) • '
        '`pip install "aimnet[hf]"`'
        "</span>",
    )

    # --- Molecule Input (compact row) ---
    # input_text, input_format and charge_input are shared: both the
    # Single Point and Optimize tabs use them as their first three inputs.
    with gr.Row(equal_height=True):
        with gr.Column(scale=2, min_width=280):
            input_text = gr.Textbox(
                lines=4, label="Molecule",
                placeholder="Enter SMILES (e.g. CCO), paste XYZ coordinates, or PDB block...",
            )
        with gr.Column(scale=1, min_width=180):
            input_format = gr.Radio(
                ["SMILES", "XYZ", "PDB"], value="SMILES", label="Format",
            )
            with gr.Row():
                charge_input = gr.Number(value=0, precision=0, label="Charge", scale=1)
                file_upload = gr.File(
                    label="Upload",
                    file_types=[".xyz", ".pdb", ".sdf", ".mol"],
                    scale=1,
                )

    # File upload handler
    def on_file_upload(file_obj):
        """Fill the shared molecule inputs from an uploaded structure file.

        Args:
            file_obj: Current value of the ``file_upload`` component; ``None``
                when no file is present.

        Returns:
            A pair of ``gr.update`` objects for ``(input_text, input_format)``.
            Both are no-op updates when there is no file, or when
            ``handle_file_upload`` raises. In the error case the message is
            reported with ``gr.Warning`` instead of being raised.
        """
        if file_obj is None:
            return gr.update(), gr.update()
        try:
            # NOTE(review): ``fmt`` becomes the Radio value, so it must be one
            # of "SMILES"/"XYZ"/"PDB"; presumably handle_file_upload maps
            # .sdf/.mol uploads onto one of those — confirm.
            text, fmt = handle_file_upload(file_obj)
            gr.Info(f"Loaded file ({fmt} format)")
            return gr.update(value=text), gr.update(value=fmt)
        except Exception as e:
            # UI boundary: report the failure and leave the current inputs
            # untouched rather than erroring out the event.
            gr.Warning(f"File upload failed: {e}")
            return gr.update(), gr.update()
    file_upload.change(
        on_file_upload, inputs=[file_upload], outputs=[input_text, input_format]
    )

    # --- Tabs ---
    # NOTE(review): ``tabs`` appears unused; the binding could be dropped.
    with gr.Tabs() as tabs:
        # ===== Tab 1: Calculate =====
        # Single-point evaluation of the shared molecule input via predict().
        with gr.TabItem("Single Point", id=0):
            with gr.Row(equal_height=False):
                # Left: controls
                with gr.Column(scale=1, min_width=220):
                    calc_forces = gr.Checkbox(value=True, label="Forces")
                    calc_hessian = gr.Checkbox(
                        value=False, label="Hessian & Frequencies"
                    )
                    calc_btn = gr.Button(
                        "Calculate", variant="primary", size="lg",
                    )
                    gr.Markdown("##### Try an example")
                    # Each CALC_EXAMPLES row fills these five inputs in order.
                    gr.Examples(
                        examples=CALC_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                calc_forces, calc_hessian],
                        label="",
                    )
                # Right: results
                with gr.Column(scale=3, min_width=500):
                    # Starts as the placeholder; replaced by each calculation.
                    calc_viewer = gr.HTML(value=VIEWER_PLACEHOLDER)
                    calc_results = gr.Markdown()
                    # Hidden until calc_wrapper receives a frequency figure.
                    calc_freq_plot = gr.Plot(visible=False)
                    with gr.Row():
                        with gr.Accordion("Download XYZ", open=False):
                            calc_xyz = gr.File(interactive=False)
                        with gr.Accordion("Python code", open=False):
                            calc_script = gr.Code(language="python")

            def calc_wrapper(text, fmt, charge, forces, hessian):
                """Run ``predict`` and adapt its result to the Single Point outputs.

                Returns a 5-tuple for ``(calc_results, calc_viewer,
                calc_freq_plot, calc_xyz, calc_script)``. The viewer falls
                back to VIEWER_PLACEHOLDER when ``predict`` returns a falsy
                viewer value. The frequency plot is made visible only when a
                figure (not ``None``) was returned.
                """
                md, viewer, fplot, xyz, script = predict(
                    text, fmt, charge, forces, hessian
                )
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )
            calc_btn.click(
                calc_wrapper,
                inputs=[input_text, input_format, charge_input,
                        calc_forces, calc_hessian],
                outputs=[calc_results, calc_viewer, calc_freq_plot,
                         calc_xyz, calc_script],
            )

        # ===== Tab 2: Optimize =====
        # Geometry optimization of the shared molecule input via optimize().
        with gr.TabItem("Optimize", id=1):
            with gr.Row(equal_height=False):
                # Left: controls
                with gr.Column(scale=1, min_width=220):
                    opt_steps = gr.Slider(
                        10, 50, value=30, step=1, label="Max steps",
                    )
                    opt_fmax = gr.Number(
                        value=0.05, label="fmax (eV/A)",
                        minimum=0.01, maximum=1.0,
                    )
                    opt_freqs = gr.Checkbox(
                        value=False, label="Frequencies at minimum",
                    )
                    opt_btn = gr.Button(
                        "Optimize", variant="primary", size="lg",
                    )
                    gr.Markdown("##### Try an example")
                    # Each OPT_EXAMPLES row fills these six inputs in order.
                    gr.Examples(
                        examples=OPT_EXAMPLES,
                        inputs=[input_text, input_format, charge_input,
                                opt_steps, opt_fmax, opt_freqs],
                        label="",
                    )
                # Right: results
                with gr.Column(scale=3, min_width=500):
                    opt_viewer = gr.HTML(value=VIEWER_PLACEHOLDER)
                    opt_conv_plot = gr.Plot(label="Convergence")
                    opt_results = gr.Markdown()
                    # Hidden until opt_wrapper receives a frequency figure.
                    opt_freq_plot = gr.Plot(visible=False)
                    with gr.Row():
                        with gr.Accordion("Download XYZ", open=False):
                            opt_xyz = gr.File(interactive=False)
                        with gr.Accordion("Python code", open=False):
                            opt_script = gr.Code(language="python")

            def opt_wrapper(text, fmt, charge, steps, fmax, freqs):
                """Run ``optimize`` and adapt its result to the Optimize outputs.

                Returns a 6-tuple for ``(opt_results, opt_viewer,
                opt_conv_plot, opt_freq_plot, opt_xyz, opt_script)``. The
                viewer falls back to VIEWER_PLACEHOLDER when ``optimize``
                returns a falsy viewer value. The frequency plot is made
                visible only when a figure (not ``None``) was returned.
                """
                md, viewer, conv, fplot, xyz, script = optimize(
                    text, fmt, charge, steps, fmax, freqs
                )
                return (
                    md,
                    viewer or VIEWER_PLACEHOLDER,
                    conv,
                    gr.update(value=fplot, visible=fplot is not None),
                    xyz,
                    script,
                )
            opt_btn.click(
                opt_wrapper,
                inputs=[input_text, input_format, charge_input,
                        opt_steps, opt_fmax, opt_freqs],
                outputs=[opt_results, opt_viewer, opt_conv_plot,
                         opt_freq_plot, opt_xyz, opt_script],
            )
def _main() -> None:
    """Start the Gradio server for the ``demo`` app built in this module."""
    demo.launch()


if __name__ == "__main__":
    _main()