# Sophia Tang
# Initial commit with LFS
# 7efee70
import torch
import numpy as np
from .utils import kabsch, aldp_diff, tic_diff
class Metric:
    """Evaluate saved sampled trajectories against a target conformation.

    Loads the trajectories stored at ``{save_dir}/positions/{i}.npy`` and
    reports:

    - RMSD of the final frame to the target, over heavy atoms and after
      Kabsch alignment;
    - THP: the fraction of samples whose final frame lands near the target
      in a 2-D collective-variable plane;
    - ETS: the maximum potential energy along each trajectory that hit the
      target, averaged over the hits.
    """

    def __init__(self, args, mds):
        """Copy run configuration from ``args`` and system handles from ``mds``.

        Args:
            args: run configuration. Must provide ``device``, ``molecule``,
                ``save_dir``, ``timestep``, ``friction`` and ``num_samples``.
            mds: simulation wrapper. Must provide ``m``, ``std``,
                ``log_prob``, ``heavy_atoms``, ``energy_function`` and
                ``target_position``.
        """
        self.device = args.device
        self.molecule = args.molecule
        self.save_dir = args.save_dir
        self.timestep = args.timestep
        self.friction = args.friction
        self.num_samples = args.num_samples
        self.m = mds.m
        self.std = mds.std
        self.log_prob = mds.log_prob
        self.heavy_atoms = mds.heavy_atoms
        self.energy_function = mds.energy_function
        self.target_position = mds.target_position

    def __call__(self):
        """Load every saved trajectory and compute the summary metrics.

        Returns:
            dict with these keys:

            - ``"rmsd"`` / ``"rmsd_std"``: mean/std RMSD multiplied by 10
              (presumably nm -> Angstrom; TODO confirm units).
            - ``"thp"``: hit fraction multiplied by 100, i.e. a percentage.
            - ``"ets"`` / ``"ets_std"``: both ``None`` when no trajectory hit
              the target.
        """
        final_positions, potentials = [], []
        for i in range(self.num_samples):
            position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            # The energy function also returns forces; no metric here uses them.
            _, potential = self.energy_function(position)
            # Only the last frame feeds RMSD/THP, so keep just that on the
            # device instead of the whole trajectory.
            final_positions.append(torch.from_numpy(position[-1]).to(self.device))
            potentials.append(torch.from_numpy(potential).to(self.device))
        final_position = torch.stack(final_positions)

        rmsd, rmsd_std = self.rmsd(
            final_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        thp, hit = self.thp(final_position, self.target_position)
        ets, ets_std = self.ets(hit, potentials)
        return {
            "rmsd": 10 * rmsd,
            "thp": 100 * thp,
            "ets": ets,
            "rmsd_std": 10 * rmsd_std,
            "ets_std": ets_std,
        }

    def rmsd(self, position, target_position):
        """Return mean and std of per-sample RMSD after Kabsch alignment.

        Assumes ``position`` is (samples, atoms, 3). ``kabsch`` comes from
        ``.utils`` and is presumed to return a rotation ``R`` and a
        translation ``t`` that align ``position`` onto ``target_position``
        (TODO confirm against its definition).

        The std is the unbiased estimate, so it is ``nan`` for one sample.
        """
        R, t = kabsch(position, target_position)
        aligned = torch.matmul(position, R.transpose(-2, -1)) + t
        # Squared distance per atom, mean over atoms, sqrt -> one RMSD per sample.
        per_sample = (aligned - target_position).square().sum(-1).mean(-1).sqrt()
        return per_sample.mean().item(), per_sample.std().item()

    def thp(self, position, target_position, radius=0.75):
        """Return the fraction of samples that land near the target, and the mask.

        The difference to the target is computed in a 2-D CV plane:

        - for ``"aldp"``, the psi/phi differences from ``aldp_diff``;
        - otherwise, the two differences from ``tic_diff``.

        A sample is a hit when the Euclidean norm of the 2-D difference is
        strictly less than ``radius``.

        Args:
            position: final-frame positions, one per sample.
            target_position: target conformation.
            radius: hit radius in the units the diff helpers return.
                Defaults to 0.75, the previously hard-coded value.

        Returns:
            A pair ``(fraction, mask)``:

            - ``fraction``: hit fraction as a Python float;
            - ``mask``: 1-D boolean mask with one entry per sample.
        """
        if self.molecule == "aldp":
            first_diff, second_diff = aldp_diff(position, target_position)  # psi, phi
        else:
            first_diff, second_diff = tic_diff(self.molecule, position, target_position)
        hit = first_diff.square() + second_diff.square() < radius ** 2
        return self._hit_rate(hit)

    @staticmethod
    def _hit_rate(hit):
        """Flatten a per-sample hit mask to 1-D; return (hit fraction, mask).

        ``reshape(-1)`` replaces the former ``squeeze()``. With a single
        sample, ``squeeze()`` produced a 0-d tensor, which ``len()`` and
        iteration reject. Assumes the mask holds one entry per sample,
        e.g. (samples,) or (samples, 1).
        """
        hit = hit.reshape(-1)
        return hit.float().mean().item(), hit

    def ets(self, hit, potentials):
        """Return mean/std of the peak potential over trajectories that hit.

        Args:
            hit: boolean mask with one entry per trajectory (as returned by
                ``thp``).
            potentials: per-trajectory potential tensors, indexed like
                ``hit``. The peak is the max over dim 0.

        Returns:
            ``(mean, std)`` as Python floats, or ``(None, None)`` if nothing
            hit. The std is unbiased, so it is ``nan`` when exactly one
            trajectory hit.
        """
        # One host transfer for the whole mask instead of a sync per element.
        flags = hit.reshape(-1).tolist()
        peaks = [
            potentials[i].max(dim=0).values
            for i, is_hit in enumerate(flags)
            if is_hit
        ]
        if not peaks:
            return None, None
        # stack keeps the potentials' device and dtype.
        peaks = torch.stack(peaks)
        return peaks.mean().item(), peaks.std().item()