File size: 2,862 Bytes
7efee70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

import torch
import numpy as np

from .utils import kabsch, aldp_diff, tic_diff


class Metric:
    """Score a batch of saved trajectories against a target structure.

    ``__call__`` loads ``num_samples`` trajectories from
    ``{save_dir}/positions/{i}.npy`` and evaluates each one with the
    system's ``energy_function``. It then reports three metrics:

    * ``rmsd``: heavy-atom RMSD of each trajectory's final frame to
      ``target_position`` after Kabsch alignment (mean and std over samples).
    * ``thp``: fraction of samples whose final frame lies within a radius of
      the target. Distance is measured in the 2-D space returned by
      ``aldp_diff`` (molecule ``"aldp"``) or ``tic_diff`` (otherwise).
      The name presumably means "target hit percentage"; confirm.
    * ``ets``: mean and std, over the hit samples only, of the maximum
      potential energy reached along each trajectory.
    """

    def __init__(self, args, mds):
        # Run configuration (``args`` is presumably the parsed CLI namespace).
        self.device = args.device
        self.molecule = args.molecule
        self.save_dir = args.save_dir
        self.timestep = args.timestep
        self.friction = args.friction
        self.num_samples = args.num_samples
        # Quantities taken from the molecular system ``mds``. ``m``, ``std``,
        # ``log_prob``, ``timestep`` and ``friction`` are not read by the
        # methods of this class; presumably they are kept for callers or
        # subclasses.
        self.m = mds.m
        self.std = mds.std
        self.log_prob = mds.log_prob
        self.heavy_atoms = mds.heavy_atoms
        self.energy_function = mds.energy_function
        self.target_position = mds.target_position

    def __call__(self):
        """Load all saved trajectories and compute the summary metrics.

        Returns:
            A dict with these keys:

            * ``"rmsd"`` / ``"rmsd_std"``: RMSD mean/std multiplied by 10.
              Presumably this converts nm to Angstrom; confirm.
            * ``"thp"``: hit fraction as a percentage (0-100).
            * ``"ets"`` / ``"ets_std"``: peak-potential mean/std over hit
              samples. Both are ``None`` when no sample hit.
        """
        final_positions, potentials = [], []
        for i in range(self.num_samples):
            position = np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            # The force output is not used by any metric; only the potential
            # along the trajectory is kept.
            _, potential = self.energy_function(position)
            # Only the last frame enters rmsd/thp, so keep just that one on
            # the device instead of the full trajectory.
            final_positions.append(torch.from_numpy(position[-1]).to(self.device))
            potentials.append(torch.from_numpy(potential).to(self.device))
        final_position = torch.stack(final_positions)
        rmsd, rmsd_std = self.rmsd(
            final_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        thp, hit = self.thp(final_position, self.target_position)
        ets, ets_std = self.ets(hit, potentials)
        metrics = {
            "rmsd": 10 * rmsd,
            "thp": 100 * thp,
            "ets": ets,
            "rmsd_std": 10 * rmsd_std,
            "ets_std": ets_std,
        }
        return metrics

    def rmsd(self, position, target_position):
        """Return the mean and std of the per-sample RMSD after Kabsch alignment.

        ``kabsch`` supplies a rotation ``R`` and a translation ``t``. They are
        applied to ``position`` as ``position @ R^T + t`` before the
        comparison with ``target_position``. The squared error is summed over
        the last axis and averaged over the second-to-last axis. Presumably
        ``position`` is (samples, atoms, 3), giving one RMSD per sample;
        confirm against ``kabsch``.

        Returns:
            ``(mean, std)`` as Python floats. ``std`` is the unbiased sample
            std, so it is NaN when there is a single sample.
        """
        R, t = kabsch(position, target_position)
        aligned = torch.matmul(position, R.transpose(-2, -1)) + t
        per_sample = (aligned - target_position).square().sum(-1).mean(-1).sqrt()
        return per_sample.mean().item(), per_sample.std().item()

    def thp(self, position, target_position, radius=0.75):
        """Return the fraction of samples that reach the target region.

        A sample "hits" when ``d1**2 + d2**2 < radius**2``. Here ``(d1, d2)``
        is the pair returned by ``aldp_diff`` (molecule ``"aldp"``) or by
        ``tic_diff`` (any other molecule).

        Args:
            position: Final frames of all samples.
            target_position: Target structure.
            radius: Hit radius, in the units of the diff helpers.

        Returns:
            ``(fraction, hit)``:

            * ``fraction`` is a Python float in [0, 1].
            * ``hit`` is the boolean hit mask flattened to 1-D. It has one
              entry per sample, assuming the diff helpers return one value
              per sample.
        """
        if self.molecule == "aldp":
            first_diff, second_diff = aldp_diff(position, target_position)
        else:
            first_diff, second_diff = tic_diff(self.molecule, position, target_position)
        hit = first_diff.square() + second_diff.square() < radius ** 2
        # reshape(-1) rather than squeeze(): squeeze() collapses a
        # single-sample result to a 0-d tensor. len() and iteration both
        # raise on a 0-d tensor.
        hit = hit.reshape(-1)
        return hit.float().mean().item(), hit

    def ets(self, hit, potentials):
        """Summarise the peak potential energy of the trajectories that hit.

        Args:
            hit: Boolean per-sample hit mask, as returned by ``thp``. A 0-d
                mask is treated as a single sample.
            potentials: Sequence of per-trajectory potential tensors. The
                maximum is taken along dim 0 of each.

        Returns:
            ``(mean, std)`` of the per-trajectory maxima over hit samples, as
            Python floats. Returns ``(None, None)`` when no sample hit.
            ``std`` is the unbiased sample std, so it is NaN when exactly one
            sample hit.
        """
        # Move the whole mask to the host once. Testing each element of a
        # GPU tensor would force one device sync per element.
        hit_flags = torch.as_tensor(hit).reshape(-1).tolist()
        peaks = [potentials[i].max(0)[0] for i, flag in enumerate(hit_flags) if flag]
        if not peaks:
            return None, None
        # torch.stack keeps dtype and device. torch.tensor() on a list of
        # tensors copies element-wise and rejects multi-element tensors.
        peaks = torch.stack(peaks)
        return peaks.mean().item(), peaks.std().item()