Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import tempfile | |
| import os | |
| from ase.io import read | |
| from ase import units | |
| from ase.optimize import LBFGS | |
| from ase.md.verlet import VelocityVerlet | |
| from ase.md.velocitydistribution import MaxwellBoltzmannDistribution | |
| from ase.md import MDLogger | |
| from ase.io.trajectory import Trajectory | |
| import py3Dmol | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
| # ----------------------------- | |
| # Global model | |
| # ----------------------------- | |
model_calc = None


def load_orbmol_model():
    """Return the shared OrbMol calculator, building it on first use.

    The calculator is stored in the module-level ``model_calc`` so that
    repeated calls reuse the same object. The model is created on the CPU
    with ``float32-high`` precision.

    Returns:
        The cached ``ORBCalculator``, or ``None`` when loading failed. A
        failed load leaves the cache empty, so the next call tries again.
    """
    global model_calc

    # Fast path: a calculator already exists.
    if model_calc is not None:
        return model_calc

    try:
        print("Loading OrbMol model...")
        forcefield = pretrained.orb_v3_conservative_inf_omat(
            device="cpu",
            precision="float32-high",
        )
        model_calc = ORBCalculator(forcefield, device="cpu")
        print("β OrbMol model loaded successfully")
    except Exception as err:
        # Report the failure and signal it to callers by returning None.
        print(f"β Error loading model: {err}")
        model_calc = None

    return model_calc
| # ----------------------------- | |
| # Single-point calculation | |
| # ----------------------------- | |
def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
    """Run a single-point OrbMol calculation on an XYZ structure.

    Args:
        xyz_content: Text of an XYZ file describing the structure.
        charge: Total charge; cast to ``int`` and stored as
            ``atoms.info["charge"]``.
        spin_multiplicity: Cast to ``int`` and stored as
            ``atoms.info["spin"]``.

    Returns:
        A ``(result, status)`` tuple of strings. On success ``result`` is a
        report with the total energy, the per-atom forces and the largest
        force norm. On failure ``result`` holds the error message and
        ``status`` is ``""`` (bad input / model unavailable) or ``"Error"``
        (exception during the calculation).
    """
    xyz_file = None
    try:
        # Validate the input first so an empty request never triggers a
        # model load; also treats None like an empty string.
        if not xyz_content or not xyz_content.strip():
            return "β Error: Please enter XYZ coordinates", ""

        calc = load_orbmol_model()
        if calc is None:
            return "β Error: Could not load OrbMol model", ""

        # ase.io.read is given a path, so stage the text in a temp file.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as handle:
            xyz_file = handle.name
            handle.write(xyz_content)

        atoms = read(xyz_file)
        atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
        atoms.calc = calc

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()

        force_lines = "".join(
            f"Atom {index}: [{fx:.4f}, {fy:.4f}, {fz:.4f}] eV/Γ \n"
            for index, (fx, fy, fz) in enumerate(forces, start=1)
        )
        max_force = np.max(np.linalg.norm(forces, axis=1))
        result = (
            f"π **Total Energy**: {energy:.6f} eV\n\nβ‘ **Atomic Forces**:\n"
            f"{force_lines}"
            f"\nπ **Max Force**: {max_force:.4f} eV/Γ "
        )
        return result, "β Calculation completed with OrbMol"
    except Exception as e:
        return f"β Error during calculation: {str(e)}", "Error"
    finally:
        # Remove the staged file on every exit path, not only on success.
        if xyz_file is not None:
            try:
                os.unlink(xyz_file)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the result.
                pass
| # ----------------------------- | |
| # Helper: convert trajectory → HTML animation | |
| # ----------------------------- | |
def _frames_to_xyz(frames):
    """Serialize an iterable of ASE ``Atoms`` frames as one multi-frame XYZ string."""
    blocks = []
    for index, atoms in enumerate(frames):
        symbols = atoms.get_chemical_symbols()
        positions = atoms.get_positions()
        # An XYZ frame is: atom-count line, comment line, then one line per atom.
        lines = [str(len(symbols)), f"frame {index}"]
        lines.extend(f"{s} {x} {y} {z}" for s, (x, y, z) in zip(symbols, positions))
        blocks.append("\n".join(lines) + "\n")
    return "".join(blocks)


def traj_to_html(traj_file):
    """Render an ASE trajectory file as an animated py3Dmol HTML snippet.

    Args:
        traj_file: Path of the trajectory file to read.

    Returns:
        The HTML produced by the py3Dmol view (400x400, stick style,
        forward-looping animation over the trajectory frames).
    """
    # Close the reader as soon as the frames have been serialized.
    with Trajectory(traj_file) as traj:
        xyz_frames = _frames_to_xyz(traj)

    view = py3Dmol.view(width=400, height=400)
    # Load every frame into one model so animate() can step through them;
    # separate addModel() calls would overlay the structures instead.
    view.addModelsAsFrames(xyz_frames, "xyz")
    view.setStyle({"stick": {}})
    view.zoomTo()
    view.animate({"loop": "forward"})
    return view._make_html()
| # ----------------------------- | |
| # Molecular dynamics simulation | |
| # ----------------------------- | |
def run_md(xyz_content, charge=0, spin_multiplicity=1, steps=100, temperature=300, timestep=1.0):
    """Relax a structure briefly, then run velocity-Verlet MD and animate it.

    Args:
        xyz_content: Text of an XYZ file describing the structure.
        charge: Total charge; cast to ``int`` and stored as
            ``atoms.info["charge"]``.
        spin_multiplicity: Cast to ``int`` and stored as
            ``atoms.info["spin"]``.
        steps: Number of MD steps; cast to ``int``.
        temperature: Temperature in K reported in the status message.
            Initial velocities are drawn at twice this value.
        timestep: MD timestep, multiplied by ``ase.units.fs``.

    Returns:
        A ``(status, html)`` tuple of strings. ``html`` is the trajectory
        animation on success and ``""`` on any failure.
    """
    xyz_file = None
    traj_path = None
    try:
        # Validate the input first so an empty request never triggers a
        # model load; also treats None like an empty string.
        if not xyz_content or not xyz_content.strip():
            return "β Error: Please enter XYZ coordinates", ""

        calc = load_orbmol_model()
        if calc is None:
            return "β Error: Could not load OrbMol model", ""

        # UI sliders may deliver floats; the integrator needs a step count.
        steps = int(steps)

        # ase.io.read is given a path, so stage the text in a temp file.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as xyz_handle:
            xyz_file = xyz_handle.name
            xyz_handle.write(xyz_content)

        atoms = read(xyz_file)
        atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
        atoms.calc = calc

        # Pre-relaxation: at most 20 LBFGS steps towards fmax=0.05.
        opt = LBFGS(atoms)
        opt.run(fmax=0.05, steps=20)

        # NOTE(review): velocities are drawn at 2x the requested temperature,
        # presumably to offset equipartition after starting from a relaxed
        # geometry - confirm this is intended.
        MaxwellBoltzmannDistribution(atoms, temperature_K=2 * temperature)

        dyn = VelocityVerlet(atoms, timestep=timestep * units.fs)

        # Reserve a path and release the handle before ASE opens the file.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_handle:
            traj_path = traj_handle.name

        # Closing the writer guarantees every frame is on disk before reading.
        with Trajectory(traj_path, "w", atoms) as traj:
            dyn.attach(traj.write, interval=1)
            dyn.run(steps)

        html = traj_to_html(traj_path)
        return f"β MD completed: {steps} steps at {temperature} K", html
    except Exception as e:
        return f"β Error during MD simulation: {str(e)}", ""
    finally:
        # Remove the scratch files on every exit path.
        for path in (xyz_file, traj_path):
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the result.
                pass
| # ----------------------------- | |
| # Gradio UI | |
| # ----------------------------- | |
# Two-tab interface: one tab per backend function.
with gr.Blocks(title="OrbMol + MD Demo", theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# OrbMol Demo with Molecular Dynamics")

    # Tab 1: single-point calculation via predict_molecule.
    with gr.Tab("Single Point Energy"):
        xyz_input = gr.Textbox(lines=12, label="XYZ Coordinates")
        charge_input = gr.Slider(label="Charge", minimum=-10, maximum=10, step=1, value=0)
        spin_input = gr.Slider(label="Spin Multiplicity", minimum=1, maximum=11, step=1, value=1)
        run_btn = gr.Button("Run OrbMol Calculation")
        results_output = gr.Textbox(lines=15, label="Results")
        status_output = gr.Textbox(label="Status")

        # Input order matches predict_molecule's positional parameters.
        run_btn.click(
            fn=predict_molecule,
            inputs=[xyz_input, charge_input, spin_input],
            outputs=[results_output, status_output],
        )

    # Tab 2: molecular dynamics via run_md.
    with gr.Tab("Molecular Dynamics"):
        xyz_input_md = gr.Textbox(lines=12, label="XYZ Coordinates")
        charge_input_md = gr.Slider(label="Charge", minimum=-10, maximum=10, step=1, value=0)
        spin_input_md = gr.Slider(label="Spin Multiplicity", minimum=1, maximum=11, step=1, value=1)
        steps_input = gr.Slider(label="Steps", minimum=10, maximum=1000, step=10, value=100)
        temp_input = gr.Slider(label="Temperature (K)", minimum=10, maximum=1000, step=10, value=300)
        timestep_input = gr.Slider(label="Timestep (fs)", minimum=0.1, maximum=5.0, step=0.1, value=1.0)
        run_md_btn = gr.Button("Run MD Simulation")
        md_status = gr.Textbox(lines=2, label="MD Status")
        md_view = gr.HTML()

        # Input order matches run_md's positional parameters.
        run_md_btn.click(
            fn=run_md,
            inputs=[
                xyz_input_md,
                charge_input_md,
                spin_input_md,
                steps_input,
                temp_input,
                timestep_input,
            ],
            outputs=[md_status, md_view],
        )
# Load the model eagerly when the module is imported; load_orbmol_model
# caches the calculator, so later calls reuse it.
print("π Starting OrbMol model loading...")
load_orbmol_model()

if __name__ == "__main__":
    # 0.0.0.0 makes the server listen on all network interfaces.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)