Spaces:
Running
Running
| from pathlib import Path | |
| from typing import Literal | |
| import numpy as np | |
| import pandas as pd | |
| from ase import Atoms, units | |
| from ase.io import read | |
| from loguru import logger | |
| from tqdm.auto import tqdm | |
def get_runtime_stats(traj: list[Atoms], atoms0: Atoms):
    """Compute runtime statistics for an ASE trajectory.

    Parameters
    ----------
    traj : list[ase.Atoms]
        Sequence of ASE Atoms objects representing trajectory frames. Each
        frame is expected to have an `info` dict containing at least the keys
        'restart', 'datetime', and 'step'. The first frame must additionally
        provide 'target_steps'.
    atoms0 : ase.Atoms
        Reference Atoms object (typically the first frame) used to compute
        center-of-mass drift and to determine the number of atoms.

    Returns
    -------
    dict
        A dictionary containing the following keys:

        - 'natoms': number of atoms in atoms0.
        - 'total_time_seconds': wall-clock time summed across unique restart
          blocks (seconds).
        - 'total_steps': MD steps summed across unique restart blocks.
        - 'steps_per_second': throughput (0 if total_time_seconds == 0).
        - 'seconds_per_step': average seconds per step (inf if
          total_steps == 0).
        - 'seconds_per_step_per_atom': seconds per step divided by natoms
          (inf if total_steps == 0).
        - 'energies': list of potential energies of the valid frames.
        - 'kinetic_energies': list of kinetic energies of the valid frames.
        - 'temperatures': list of temperatures of the valid frames.
        - 'pressures': list with one entry per valid frame, the mean of the
          first three stress components, or NaN when the stress cannot be
          obtained for that frame.
        - 'target_steps': target number of steps taken from traj[0].info.
        - 'final_step': last recorded step number (0 if no valid frames).
        - 'timestep': array of step numbers of the valid frames.
        - 'com_drifts': list of center-of-mass drift vectors relative to
          atoms0.

    Raises
    ------
    IndexError
        If `traj` is empty.
    KeyError
        If the first frame has no 'target_steps' entry, or a valid frame is
        missing 'restart', 'datetime' or 'step'.

    Notes
    -----
    Frames whose potential energy cannot be queried, or is not finite, are
    skipped entirely. Unique restart blocks are identified by
    atoms.info['restart']; elapsed time and step count are measured between
    the first and last valid frame of each block and summed over blocks.
    """
    restarts = []
    steps, times = [], []
    Ts, Ps, Es, KEs = [], [], [], []
    com_drifts = []

    # The reference centre of mass does not change while iterating.
    com0 = atoms0.get_center_of_mass()

    for atoms in traj:
        try:
            energy = atoms.get_potential_energy()
            # Explicit check instead of `assert`, which is stripped under -O.
            if not np.isfinite(energy):
                continue
        except Exception:
            # Frames without a usable energy are treated as invalid.
            continue

        restarts.append(atoms.info["restart"])
        times.append(atoms.info["datetime"])
        steps.append(atoms.info["step"])

        Es.append(energy)
        KEs.append(atoms.get_kinetic_energy())
        Ts.append(atoms.get_temperature())

        try:
            Ps.append(atoms.get_stress()[:3].mean())
        except Exception:
            # Keep the list aligned with the other per-frame lists.
            Ps.append(np.nan)

        com_drifts.append((atoms.get_center_of_mass() - com0).tolist())

    restarts = np.array(restarts)
    times = np.array(times)
    steps = np.array(steps)

    total_time_seconds = 0
    total_steps = 0

    # Accumulate elapsed time and steps block by block so that gaps between
    # restarts are not counted.
    for block in np.unique(restarts):
        mask = restarts == block
        block_times = times[mask]
        block_steps = steps[mask]
        total_time_seconds += (block_times[-1] - block_times[0]).total_seconds()
        total_steps += block_steps[-1] - block_steps[0]

    # 'target_steps' is stored on the first frame of the trajectory.
    target_steps = traj[0].info["target_steps"]
    natoms = len(atoms0)

    return {
        "natoms": natoms,
        "total_time_seconds": total_time_seconds,
        "total_steps": total_steps,
        "steps_per_second": total_steps / total_time_seconds
        if total_time_seconds != 0
        else 0,
        "seconds_per_step": total_time_seconds / total_steps
        if total_steps != 0
        else float("inf"),
        "seconds_per_step_per_atom": total_time_seconds / total_steps / natoms
        if total_steps != 0
        else float("inf"),
        "energies": Es,
        "kinetic_energies": KEs,
        "temperatures": Ts,
        "pressures": Ps,
        "target_steps": target_steps,
        "final_step": steps[-1] if len(steps) != 0 else 0,
        "timestep": steps,
        "com_drifts": com_drifts,
    }
def gather_results(
    run_dir: Path, prefix: str, run_type: Literal["nvt", "npt"]
) -> pd.DataFrame:
    """Collect per-frame runtime statistics from all matching trajectories.

    Every file in `run_dir` matching ``{prefix}_*{run_type}.traj`` is read
    with `ase.io.read` and summarised with `get_runtime_stats`. The rows of
    all files are stacked into a single DataFrame.

    Parameters
    ----------
    run_dir : Path
        Directory containing the trajectory files.
    prefix : str
        Leading part of the file names to match; also used as the progress
        bar label.
    run_type : {"nvt", "npt"}
        Ensemble tag that the file name must end with (before ``.traj``).

    Returns
    -------
    pd.DataFrame
        The columns are 'formula', 'normalized_timestep',
        'normalized_final_step', 'pressure' (the 'pressures' statistic
        divided by `units.GPa`), followed by every key returned by
        `get_runtime_stats`. An empty DataFrame is returned when no file
        could be processed.

    Notes
    -----
    Files that cannot be read or processed are logged as warnings and
    skipped; they never abort the whole collection.
    """
    run_dir = Path(run_dir)
    # Sort so the row order does not depend on filesystem enumeration order.
    files = sorted(run_dir.glob(f"{prefix}_*{run_type}.traj"))

    # Collect per-file frames and concatenate once at the end; concatenating
    # inside the loop re-copies all accumulated rows for every file.
    frames = []

    for fpath in tqdm(files, desc=prefix):
        try:
            traj = read(fpath, index=":")
        except Exception as e:
            logger.warning(f"Error reading {fpath}: {e}")
            continue

        try:
            stats = get_runtime_stats(traj, atoms0=traj[0])
            frames.append(
                pd.DataFrame(
                    {
                        "formula": traj[0].get_chemical_formula(),
                        "normalized_timestep": stats["timestep"]
                        / stats["target_steps"],
                        "normalized_final_step": stats["final_step"]
                        / stats["target_steps"],
                        "pressure": np.array(stats["pressures"]) / units.GPa,
                    }
                    | stats
                )
            )
        except Exception as e:
            logger.warning(f"Error processing {fpath}: {e}")
            continue

    if not frames:
        # pd.concat raises on an empty list; keep the empty-frame contract.
        return pd.DataFrame()

    return pd.concat(frames, ignore_index=True)