github-actions[ci]
Clean sync from main branch - 2025-10-16 23:00:12
afe68b4
from pathlib import Path
from typing import Literal
import numpy as np
import pandas as pd
from ase import Atoms, units
from ase.io import read
from loguru import logger
from tqdm.auto import tqdm
def get_runtime_stats(traj: list[Atoms], atoms0: Atoms):
    """Compute runtime statistics for an ASE trajectory.

    Parameters
    ----------
    traj : list[ase.Atoms]
        Sequence of ASE Atoms objects representing trajectory frames. Every
        frame whose potential energy can be evaluated must have an ``info``
        dict containing at least the keys 'restart', 'datetime' and 'step'.
        At least one frame must record 'target_steps'; the first frame that
        does is used.
    atoms0 : ase.Atoms
        Reference Atoms object (typically the first frame) used to compute
        center-of-mass drift and to determine the number of atoms.

    Returns
    -------
    dict
        A dictionary containing the following keys:
        - 'natoms': int, number of atoms in atoms0.
        - 'total_time_seconds': float, wall-clock time summed over the
          restart blocks (seconds); gaps between blocks are excluded.
        - 'total_steps': total MD steps summed over the restart blocks.
        - 'steps_per_second': float, throughput (0 if total_time_seconds == 0).
        - 'seconds_per_step': float, average seconds per step
          (inf if total_steps == 0).
        - 'seconds_per_step_per_atom': float, seconds per step divided by
          natoms (inf if total_steps == 0).
        - 'energies': list of potential energies of the valid frames.
        - 'kinetic_energies': list of kinetic energies of the valid frames.
        - 'temperatures': list of temperatures of the valid frames.
        - 'pressures': list with the mean of the first three (diagonal)
          stress components per valid frame; NaN where the stress could not
          be obtained.
        - 'target_steps': value of info['target_steps'] from the first frame
          that records it.
        - 'final_step': last recorded step number (0 if no valid frames).
        - 'timestep': numpy array of step numbers of the valid frames.
        - 'com_drifts': list of center-of-mass displacement vectors (as
          lists) relative to atoms0.

    Raises
    ------
    KeyError
        If no frame records 'target_steps', or a valid frame lacks
        'restart', 'datetime' or 'step'.

    Notes
    -----
    A frame is valid when querying its potential energy succeeds and yields
    a finite value; all other frames are skipped. Frames are grouped by
    ``info['restart']``; each group contributes (last - first) of its
    datetimes and steps, so frames are assumed to be in chronological order
    within a group.
    """
    restarts, times, steps = [], [], []
    Ts, Ps, Es, KEs = [], [], [], []
    com_drifts = []
    # The reference center of mass does not change inside the loop.
    com0 = atoms0.get_center_of_mass()

    for atoms in traj:
        # Best-effort: frames without a usable energy are skipped. The finite
        # check is explicit (not `assert`) so it still runs under `python -O`.
        try:
            energy = atoms.get_potential_energy()
            if not np.isfinite(energy):
                continue
        except Exception:
            continue

        info = atoms.info
        restarts.append(info["restart"])
        times.append(info["datetime"])
        steps.append(info["step"])
        Es.append(energy)
        KEs.append(atoms.get_kinetic_energy())
        Ts.append(atoms.get_temperature())
        try:
            # NOTE(review): this is the mean diagonal stress; under ASE's sign
            # convention hydrostatic pressure is usually its negative — confirm
            # downstream consumers expect this sign.
            Ps.append(atoms.get_stress()[:3].mean())
        except Exception:
            Ps.append(np.nan)
        com_drifts.append((atoms.get_center_of_mass() - com0).tolist())

    restarts = np.asarray(restarts)
    times = np.asarray(times)
    steps = np.asarray(steps)

    # Each distinct 'restart' value marks one contiguous MD segment. Summing
    # per-segment (last - first) deltas ignores the idle time between them.
    total_time_seconds = 0.0
    total_steps = 0
    for block in np.unique(restarts):
        mask = restarts == block
        block_times = times[mask]
        block_steps = steps[mask]
        total_time_seconds += (block_times[-1] - block_times[0]).total_seconds()
        total_steps += block_steps[-1] - block_steps[0]

    # Use the first frame that records 'target_steps' rather than a fixed
    # index, so single-frame trajectories do not raise IndexError.
    for frame in traj:
        if "target_steps" in frame.info:
            target_steps = frame.info["target_steps"]
            break
    else:
        raise KeyError("no trajectory frame provides info['target_steps']")

    natoms = len(atoms0)
    return {
        "natoms": natoms,
        "total_time_seconds": total_time_seconds,
        "total_steps": total_steps,
        "steps_per_second": total_steps / total_time_seconds
        if total_time_seconds != 0
        else 0,
        "seconds_per_step": total_time_seconds / total_steps
        if total_steps != 0
        else float("inf"),
        "seconds_per_step_per_atom": total_time_seconds / total_steps / natoms
        if total_steps != 0
        else float("inf"),
        "energies": Es,
        "kinetic_energies": KEs,
        "temperatures": Ts,
        "pressures": Ps,
        "target_steps": target_steps,
        "final_step": steps[-1] if len(steps) != 0 else 0,
        "timestep": steps,
        "com_drifts": com_drifts,
    }
def gather_results(
    run_dir: Path, prefix: str, run_type: Literal["nvt", "npt"]
) -> pd.DataFrame:
    """Collect per-frame runtime statistics for all matching trajectory files.

    Parameters
    ----------
    run_dir : Path
        Directory searched (non-recursively) for trajectory files.
    prefix : str
        Filename prefix; files matching ``f"{prefix}_*{run_type}.traj"`` are
        read. Also used as the progress-bar label.
    run_type : {"nvt", "npt"}
        Suffix used in the filename pattern.

    Returns
    -------
    pandas.DataFrame
        One row per valid frame across all trajectories. The columns are
        'formula', 'normalized_timestep', 'normalized_final_step' and
        'pressure' (the per-frame 'pressures' divided by ``units.GPa``), plus
        every key returned by ``get_runtime_stats``. An empty DataFrame is
        returned when no file could be processed.

    Notes
    -----
    Files that cannot be read or processed are logged as warnings and
    skipped.
    """
    run_dir = Path(run_dir)
    # Sorted so the row order does not depend on filesystem listing order.
    files = sorted(run_dir.glob(f"{prefix}_*{run_type}.traj"))
    pieces = []
    for fpath in tqdm(files, desc=prefix):
        try:
            traj = read(fpath, index=":")
        except Exception as e:
            logger.warning(f"Error reading {fpath}: {e}")
            continue
        try:
            stats = get_runtime_stats(traj, atoms0=traj[0])
            target_steps = stats["target_steps"]
            pieces.append(
                pd.DataFrame(
                    {
                        "formula": traj[0].get_chemical_formula(),
                        "normalized_timestep": stats["timestep"] / target_steps,
                        "normalized_final_step": stats["final_step"]
                        / target_steps,
                        # Convert from ASE internal units to GPa.
                        "pressure": np.array(stats["pressures"]) / units.GPa,
                    }
                    | stats
                )
            )
        except Exception as e:
            logger.warning(f"Error processing {fpath}: {e}")
            continue
    # A single concat avoids the quadratic cost of growing a DataFrame
    # inside the loop.
    if not pieces:
        return pd.DataFrame()
    return pd.concat(pieces, ignore_index=True)