Spaces:
Running
Running
| """ | |
| Define equivariance testing task. | |
| """ | |
| from __future__ import annotations | |
| from collections.abc import Sequence | |
| from pathlib import Path | |
| import numpy as np | |
| from ase import Atoms | |
| from ase.calculators.calculator import BaseCalculator | |
| from prefect import task | |
| from prefect.cache_policies import INPUTS, TASK_SOURCE | |
| from scipy.spatial.transform import Rotation as R | |
| from tqdm import tqdm | |
| def generate_random_unit_vector(): | |
| """Generate a random unit vector.""" | |
| vec = np.random.normal(0, 1, 3) | |
| return vec / np.linalg.norm(vec) | |
| def rotate_molecule_arbitrary( | |
| atoms: Atoms, angle: float, axis: np.ndarray | |
| ) -> tuple[Atoms, np.ndarray]: | |
| """Rotate molecule around arbitrary axis.""" | |
| rotated_atoms = atoms.copy() | |
| positions = rotated_atoms.get_positions() | |
| rot = R.from_rotvec(np.radians(angle) * axis) | |
| rotation_mat = rot.as_matrix() | |
| rotated_positions = rot.apply(positions) | |
| rotated_atoms.set_positions(rotated_positions) | |
| cell = atoms.get_cell() | |
| rotated_cell = rot.apply(cell) | |
| rotated_atoms.set_cell(rotated_cell) | |
| return rotated_atoms, rotation_mat | |
| def compare_forces( | |
| original_forces: np.ndarray, | |
| rotated_forces: np.ndarray, | |
| rotation_mat: np.ndarray, | |
| zero_threshold: float = 1e-10, | |
| ) -> tuple[float, np.ndarray, np.ndarray, np.ndarray]: | |
| """ | |
| Compare forces before and after rotation, with handling of 0 force case. | |
| Args: | |
| original_forces: Forces before rotation (N x 3 array) | |
| rotated_forces: Forces after rotation (N x 3 array) | |
| rotation_mat: 3 x 3 rotation matrix | |
| zero_threshold: Threshold below which forces are considered zero | |
| Returns: | |
| tuple containing: | |
| - mae: Mean absolute error between forces | |
| - cosine_similarity: Cosine similarity between force vectors | |
| """ | |
| rotated_original_forces = np.dot(original_forces, rotation_mat.T) | |
| force_diff = rotated_original_forces - rotated_forces | |
| mae = np.mean(np.abs(force_diff)) | |
| original_magnitudes = np.linalg.norm(rotated_original_forces, axis=1) | |
| rotated_magnitudes = np.linalg.norm(rotated_forces, axis=1) | |
| zero_original = original_magnitudes < zero_threshold | |
| zero_rotated = rotated_magnitudes < zero_threshold | |
| both_zero = zero_original & zero_rotated | |
| either_zero = zero_original | zero_rotated | |
| one_zero = either_zero & ~both_zero | |
| cosine_similarity = np.zeros(len(original_forces)) | |
| valid_forces = ~either_zero | |
| if np.any(valid_forces): | |
| norms_product = np.linalg.norm( | |
| rotated_original_forces[valid_forces], axis=1 | |
| ) * np.linalg.norm(rotated_forces[valid_forces], axis=1) | |
| dot_products = np.sum( | |
| rotated_original_forces[valid_forces] * rotated_forces[valid_forces], axis=1 | |
| ) | |
| cosine_similarity[valid_forces] = dot_products / norms_product | |
| # If both forces are 0, cosine similarity should be 1. If one is 0, we take the conservative -1. | |
| cosine_similarity[both_zero] = 1.0 | |
| cosine_similarity[one_zero] = -1.0 | |
| return mae, cosine_similarity | |
| def save_molecule_results( | |
| aggregate_results: dict, idx_list: np.ndarray, save_path: str | Path | |
| ) -> None: | |
| """ | |
| Save all molecule results from equivariance testing to .npy files. | |
| Save the index list of the atoms for further analysis. | |
| Args: | |
| aggregate_results: Dictionary containing the aggregated results from run() | |
| idx_list: List of the indices of the atoms in the original dataset | |
| save_path: Path to save the .npy files | |
| """ | |
| save_path = Path(save_path) | |
| save_path.parent.mkdir(parents=True, exist_ok=True) | |
| all_molecule_results = aggregate_results["molecule_results"] | |
| rotation_angles = list(all_molecule_results[0]["results_by_angle"].keys()) | |
| num_molecules = len(all_molecule_results) | |
| num_angles = len(rotation_angles) | |
| num_random_axes = len( | |
| all_molecule_results[0]["results_by_angle"][rotation_angles[0]]["maes"] | |
| ) | |
| num_atoms = len( | |
| all_molecule_results[0]["results_by_angle"][rotation_angles[0]][ | |
| "cosine_similarities" | |
| ][0] | |
| ) | |
| maes = np.zeros((num_molecules, num_angles, num_random_axes)) | |
| cosine_similarities = np.zeros((num_molecules, num_angles, num_random_axes)) | |
| for mol_idx, molecule in enumerate(all_molecule_results): | |
| for angle_idx, angle in enumerate(rotation_angles): | |
| angle_results = molecule["results_by_angle"][angle] | |
| maes[mol_idx, angle_idx, :] = angle_results["maes"] | |
| cosine_similarities[mol_idx, angle_idx, :] = np.mean( | |
| angle_results["cosine_similarities"], axis=-1 | |
| ) | |
| np.save(save_path.with_name(f"{save_path.stem}_maes.npy"), maes) | |
| np.save( | |
| save_path.with_name(f"{save_path.stem}_cosine_similarities.npy"), | |
| cosine_similarities, | |
| ) | |
| np.save(save_path.with_name(f"{save_path.stem}_idx_list.npy"), idx_list) | |
| def run( | |
| atoms_list: Sequence[Atoms], | |
| idx_list: np.ndarray, | |
| calculator: BaseCalculator, | |
| save_path: str | Path | None = None, | |
| rotation_angles: list[float] | np.ndarray = None, | |
| num_random_axes: int = 100, | |
| threshold: float = 1e-3, | |
| seed: int | None = None, | |
| ) -> dict: | |
| """ | |
| Test equivariance of force predictions under rotations for multiple structures. | |
| Args: | |
| atoms_list: List of input atomic structures | |
| idx_list: List of the indices of the atoms in the original dataset | |
| calculator: Calculator to use | |
| num_rotations: Number of random rotations to test | |
| rotation_angle: Angle of rotation in degrees | |
| threshold: Threshold for considering forces equivariant | |
| seed: Random seed | |
| Returns: | |
| Dictionary containing test results | |
| """ | |
| if seed is not None: | |
| np.random.seed(seed) | |
| if rotation_angles is None: | |
| rotation_angles = np.arange(30, 361, 30) | |
| rotation_angles = np.array(rotation_angles) | |
| all_results = [] | |
| cross_molecule_cosine_sims = {angle: [] for angle in rotation_angles} | |
| cross_molecule_mae = {angle: [] for angle in rotation_angles} | |
| rotation_axes = [generate_random_unit_vector() for _ in range(num_random_axes)] | |
| total_tests = len(atoms_list) * len(rotation_angles) * num_random_axes | |
| pbar = tqdm(total=total_tests, desc="Testing rotations") | |
| for atom_idx, atoms in enumerate(atoms_list): | |
| atoms = atoms.copy() | |
| atoms.calc = calculator | |
| original_forces = atoms.get_forces() | |
| results_by_angle = { | |
| angle: { | |
| "mae": [], | |
| "cosine_similarities": [], | |
| "passed_tests": 0, | |
| "passed_mae": 0, | |
| "passed_cosine_similarity": 0, | |
| } | |
| for angle in rotation_angles | |
| } | |
| # Test each angle with multiple random axes | |
| for angle in rotation_angles: | |
| for axis in rotation_axes: | |
| rotated_atoms, rotation_mat = rotate_molecule_arbitrary( | |
| atoms, angle, axis | |
| ) | |
| rotated_atoms.calc = calculator | |
| rotated_forces = rotated_atoms.get_forces() | |
| mae, cosine_similarity = compare_forces( | |
| original_forces, rotated_forces, rotation_mat | |
| ) | |
| results_by_angle[angle]["mae"].append(mae) | |
| results_by_angle[angle]["cosine_similarities"].append(cosine_similarity) | |
| cross_molecule_cosine_sims[angle].append( | |
| float(np.mean(cosine_similarity)) | |
| ) | |
| cross_molecule_mae[angle].append(float(np.mean(mae))) | |
| mae_check = mae < threshold | |
| cosine_check = all(cosine_similarity > (1 - threshold)) | |
| results_by_angle[angle]["passed_tests"] += int( | |
| mae_check and cosine_check | |
| ) | |
| results_by_angle[angle]["passed_mae"] += int(mae_check) | |
| results_by_angle[angle]["passed_cosine_similarity"] += int(cosine_check) | |
| pbar.update(1) | |
| # Compute summary statistics | |
| for angle in rotation_angles: | |
| results = results_by_angle[angle] | |
| results["mean_cosine_similarity"] = float( | |
| np.mean(results["cosine_similarities"]) | |
| ) | |
| results["avg_mae"] = float(np.mean(results["mae"])) | |
| results["equivariant_ratio"] = results["passed_tests"] / num_random_axes | |
| results["mae_passed_ratio"] = results["passed_mae"] / num_random_axes | |
| results["cosine_passed_ratio"] = ( | |
| results["passed_cosine_similarity"] / num_random_axes | |
| ) | |
| results["passed"] = results["passed_tests"] == num_random_axes | |
| results["passed_mae"] = results["passed_mae"] == num_random_axes | |
| results["passed_cosine_similarity"] = ( | |
| results["passed_cosine_similarity"] == num_random_axes | |
| ) | |
| results["maes"] = [float(x) for x in results["mae"]] | |
| results["cosine_similarities"] = [ | |
| [float(y) for y in x] for x in results["cosine_similarities"] | |
| ] | |
| molecule_results = { | |
| "mol_idx": idx_list[atom_idx], | |
| "results_by_angle": results_by_angle, | |
| "all_passed": all( | |
| results_by_angle[angle]["passed"] for angle in rotation_angles | |
| ), | |
| "avg_cosine_similarity_by_molecule": float( | |
| np.mean( | |
| [ | |
| results_by_angle[angle]["mean_cosine_similarity"] | |
| for angle in rotation_angles | |
| ] | |
| ) | |
| ), | |
| "avg_mae_by_molecule": float( | |
| np.mean( | |
| [results_by_angle[angle]["avg_mae"] for angle in rotation_angles] | |
| ) | |
| ), | |
| "overall_equivariant_ratio": float( | |
| np.mean( | |
| [ | |
| results_by_angle[angle]["equivariant_ratio"] | |
| for angle in rotation_angles | |
| ] | |
| ) | |
| ), | |
| } | |
| all_results.append(molecule_results) | |
| pbar.close() | |
| aggregate_results = { | |
| "num_molecules": len(atoms_list), | |
| "all_molecules_passed": all(result["all_passed"] for result in all_results), | |
| "average_equivariant_ratio": float( | |
| np.mean([result["overall_equivariant_ratio"] for result in all_results]) | |
| ), | |
| "average_cosine_similarity_by_angle": { | |
| angle: float(np.mean(sims)) | |
| for angle, sims in cross_molecule_cosine_sims.items() | |
| }, | |
| "average_mae_by_angle": { | |
| angle: float(np.mean(diffs)) for angle, diffs in cross_molecule_mae.items() | |
| }, | |
| "molecule_results": all_results, | |
| } | |
| if save_path: | |
| save_molecule_results(aggregate_results, idx_list, save_path) | |
| np.save( | |
| str(save_path.with_name(f"{save_path.stem}_molecule_results.npy")), | |
| all_results, | |
| ) | |
| return aggregate_results | |