import torch
import numpy as np

from .utils import kabsch, aldp_diff, tic_diff


class Metric:
    """Compute summary metrics over saved sampling trajectories.

    For each of ``num_samples`` runs, a trajectory is loaded from
    ``{save_dir}/positions/{i}.npy`` and evaluated with the system's
    ``energy_function``. The final frames are compared against
    ``target_position``, producing three metrics:

    * RMSD after Kabsch alignment (heavy atoms only).
    * THP: the percentage of samples whose final frame "hits" the target.
    * ETS: the per-trajectory maximum potential, averaged over hitting
      samples only.
    """

    def __init__(self, args, mds):
        # Run configuration.
        self.device = args.device
        self.molecule = args.molecule
        self.save_dir = args.save_dir
        self.timestep = args.timestep
        self.friction = args.friction
        self.num_samples = args.num_samples

        # Attributes copied from the molecular-dynamics system object.
        # m, std, log_prob, timestep and friction are stored but not read by
        # the methods of this class.
        self.m = mds.m
        self.std = mds.std
        self.log_prob = mds.log_prob
        self.heavy_atoms = mds.heavy_atoms
        self.energy_function = mds.energy_function
        self.target_position = mds.target_position

    def __call__(self):
        """Load every saved trajectory and return the metric dictionary.

        Returns:
            dict with keys:

            * ``"rmsd"`` / ``"rmsd_std"``: mean / std of the RMSD, multiplied
              by 10 (presumably nm -> Angstrom; TODO confirm units).
            * ``"thp"``: percentage of samples that hit the target.
            * ``"ets"`` / ``"ets_std"``: mean / std of the maximum potential
              along hitting trajectories. Both are ``None`` when no sample
              hits.
        """
        positions, potentials = [], []
        for i in range(self.num_samples):
            position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            # energy_function also returns forces, which no metric uses.
            _, potential = self.energy_function(position)
            positions.append(torch.from_numpy(position).to(self.device))
            potentials.append(torch.from_numpy(potential).to(self.device))

        # Last frame of each trajectory, stacked along a new leading sample axis.
        final_position = torch.stack([trajectory[-1] for trajectory in positions])

        rmsd, rmsd_std = self.rmsd(
            final_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        thp, hit = self.thp(final_position, self.target_position)
        ets, ets_std = self.ets(hit, potentials)

        return {
            "rmsd": 10 * rmsd,
            "thp": 100 * thp,
            "ets": ets,
            "rmsd_std": 10 * rmsd_std,
            "ets_std": ets_std,
        }

    def rmsd(self, position, target_position):
        """Return the mean and std of the per-sample RMSD after Kabsch alignment.

        ``position`` is mapped onto ``target_position`` using the rotation ``R``
        and translation ``t`` from ``kabsch``. The per-sample RMSD is then the
        square root of the mean (over the second-to-last axis) of the squared
        distances (summed over the last axis).

        Assumes ``position`` is (num_samples, num_atoms, 3) and
        ``target_position`` broadcasts against it (TODO confirm).

        Returns:
            ``(mean, std)`` as Python floats. The std uses torch's default
            Bessel correction, so it is NaN for a single sample.
        """
        R, t = kabsch(position, target_position)
        aligned = torch.matmul(position, R.transpose(-2, -1)) + t
        per_sample = (aligned - target_position).square().sum(-1).mean(-1).sqrt()
        return per_sample.mean().item(), per_sample.std().item()

    def thp(self, position, target_position):
        """Return the fraction of samples that hit the target, plus the hit mask.

        A sample hits when its 2-D collective-variable offset from the target
        lies strictly inside a circle of radius 0.75. The offset is
        ``(psi, phi)`` from ``aldp_diff`` when ``molecule == "aldp"``, and the
        first two TIC components from ``tic_diff`` otherwise.

        Returns:
            ``(fraction, hit)``: the hit fraction as a Python float, and a
            1-D boolean tensor with one entry per sample.
        """
        if self.molecule == "aldp":
            first_diff, second_diff = aldp_diff(position, target_position)
        else:
            first_diff, second_diff = tic_diff(self.molecule, position, target_position)

        hit = first_diff.square() + second_diff.square() < 0.75 ** 2
        # Use reshape(-1) rather than squeeze(). With a single sample,
        # squeeze() yields a 0-d tensor, on which len() and iteration
        # (in ets) raise.
        hit = hit.reshape(-1)
        thp = hit.float().mean()
        return thp.item(), hit

    def ets(self, hit, potentials):
        """Return the mean/std of the maximum potential over hitting trajectories.

        Args:
            hit: boolean mask with one entry per sample (as returned by ``thp``).
            potentials: list of per-sample potential tensors. For each hitting
                sample, the maximum is taken over dim 0.

        Returns:
            ``(mean, std)`` as Python floats, or ``(None, None)`` when no sample
            hits. The std is NaN when exactly one sample hits (Bessel
            correction).
        """
        maxima = [
            potentials[i].max(0)[0]
            for i, is_hit in enumerate(hit)
            if is_hit
        ]
        if not maxima:
            return None, None

        # torch.stack keeps the values on their device and handles one-element
        # (non-scalar) maxima. torch.tensor(list) would instead convert each
        # element on the host.
        maxima = torch.stack(maxima)
        return maxima.mean().item(), maxima.std().item()