import numpy as np
import torch

from .utils import kabsch, aldp_diff, tic_diff
class Metric:
    """Score saved trajectories against a target conformation.

    Calling an instance loads ``num_samples`` trajectories from
    ``{save_dir}/positions/{i}.npy`` and returns a dict with:

    * ``rmsd`` / ``rmsd_std`` -- heavy-atom RMSD between each trajectory's
      final frame and the target after Kabsch alignment (scaled by 10),
    * ``thp`` -- share of final frames that land within ``HIT_RADIUS`` of
      the target in a 2-D collective-variable space (scaled by 100),
    * ``ets`` / ``ets_std`` -- mean and standard deviation of the maximum
      potential along each trajectory that hit the target.
    """

    # Radius of the "hit" disc around the target in collective-variable space.
    HIT_RADIUS = 0.75

    def __init__(self, args, mds):
        """Copy the evaluation settings from ``args`` and the system data from ``mds``."""
        self.device = args.device
        self.molecule = args.molecule
        self.save_dir = args.save_dir
        self.timestep = args.timestep
        self.friction = args.friction
        self.num_samples = args.num_samples

        self.m = mds.m
        self.std = mds.std
        self.log_prob = mds.log_prob
        self.heavy_atoms = mds.heavy_atoms
        self.energy_function = mds.energy_function
        self.target_position = mds.target_position

    def __call__(self):
        """Load every saved trajectory and compute all metrics.

        Returns:
            dict with keys ``rmsd``, ``thp``, ``ets``, ``rmsd_std`` and
            ``ets_std``. ``ets`` and ``ets_std`` are ``None`` when no
            trajectory hits the target.
        """
        final_frames, potentials = [], []
        for i in range(self.num_samples):
            position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            # The energy function also returns forces, but no metric uses
            # them, so they are not kept (or moved to the device).
            _, potential = self.energy_function(position)
            # Only the last frame of each trajectory is needed below.
            final_frames.append(torch.from_numpy(position[-1]).to(self.device))
            potentials.append(torch.from_numpy(potential).to(self.device))

        final_position = torch.stack(final_frames)

        rmsd, rmsd_std = self.rmsd(
            final_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        thp, hit = self.thp(final_position, self.target_position)
        ets, ets_std = self.ets(hit, potentials)

        # NOTE(review): the factor 10 is presumably nm -> Angstrom and the
        # factor 100 turns the hit fraction into a percentage -- confirm.
        return {
            "rmsd": 10 * rmsd,
            "thp": 100 * thp,
            "ets": ets,
            "rmsd_std": 10 * rmsd_std,
            "ets_std": ets_std,
        }

    def rmsd(self, position, target_position):
        """Return mean and standard deviation of the aligned RMSD.

        ``position`` is rotated and translated with the ``(R, t)`` pair
        returned by ``kabsch`` before the deviation is measured. The
        reduction sums over the last axis and averages over the one before
        it (assumes a ``(batch, atoms, 3)`` layout -- TODO confirm).

        Returns:
            ``(mean, std)`` as Python floats. ``std`` is torch's default
            (unbiased) estimate, so it is NaN for a single sample.
        """
        R, t = kabsch(position, target_position)
        aligned = torch.matmul(position, R.transpose(-2, -1)) + t
        per_sample = (aligned - target_position).square().sum(-1).mean(-1).sqrt()
        return per_sample.mean().item(), per_sample.std().item()

    def thp(self, position, target_position):
        """Return the hit fraction and the per-sample boolean hit mask.

        A sample is a hit when the squared differences of its two collective
        variables sum to less than ``HIT_RADIUS ** 2``. The variables come
        from ``aldp_diff`` for molecule ``"aldp"`` and from ``tic_diff`` for
        every other molecule.

        Returns:
            ``(fraction, hit)`` where ``fraction`` is a Python float and
            ``hit`` is a 1-D boolean tensor with one entry per sample.
        """
        if self.molecule == "aldp":
            first_diff, second_diff = aldp_diff(position, target_position)
        else:
            first_diff, second_diff = tic_diff(self.molecule, position, target_position)

        hit = first_diff.square() + second_diff.square() < self.HIT_RADIUS ** 2
        # Flatten rather than squeeze: squeeze() turns a single-sample mask
        # into a 0-d tensor, which supports neither len() nor iteration.
        hit = hit.reshape(-1)
        fraction = hit.float().mean()
        return fraction.item(), hit

    def ets(self, hit, potentials):
        """Return mean and std of the peak potential over hitting trajectories.

        Args:
            hit: boolean mask with one flag per trajectory (a tensor of any
                shape, including 0-d, or a plain sequence of bools).
            potentials: sequence of tensors, one per trajectory; the maximum
                is taken along axis 0 of each.

        Returns:
            ``(mean, std)`` as Python floats, or ``(None, None)`` when no
            trajectory is flagged. ``std`` is torch's default (unbiased)
            estimate, so it is NaN when exactly one trajectory is flagged.
        """
        # One conversion to Python bools instead of branching on tensor
        # elements one at a time.
        hit_flags = torch.as_tensor(hit).reshape(-1).tolist()
        peaks = [
            potential.max(0)[0]
            for is_hit, potential in zip(hit_flags, potentials)
            if is_hit
        ]
        if not peaks:
            return None, None

        # stack keeps dtype and device; torch.tensor(list_of_tensors) would
        # convert element by element.
        peaks = torch.stack(peaks)
        return peaks.mean().item(), peaks.std().item()