File size: 3,898 Bytes
7efee70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import joblib
import mdtraj as md
import pyemma.coordinates as coor


def pairwise_dist(x):
    """Return the Euclidean distance between every pair of points in ``x``.

    Args:
        x: points along the second-to-last axis and coordinates along the
            last axis, e.g. (N, D) or batched (B, N, D).

    Returns:
        Tensor of shape (..., N, N) whose entry [..., i, j] is the L2 distance
        between point i and point j (``torch.cdist`` with its default p=2).
    """
    return torch.cdist(x1=x, x2=x)


def kabsch(P, Q):
    """Find the rigid motion (R, t) that best maps point set ``P`` onto ``Q``.

    Solves ``min_{R,t} sum_i ||R p_i + t - q_i||^2`` with the Kabsch algorithm
    (SVD of the cross-covariance matrix), constraining R to a proper rotation
    (det(R) = +1) rather than a reflection.

    Args:
        P: source points, shape (..., N, D) (D = 3 for molecules).
        Q: target points, same shape as ``P``.

    Returns:
        R: rotation matrices, shape (..., D, D).
        t: translations, shape (..., 1, D).
        The aligned source is ``P @ R.transpose(-2, -1) + t``.
    """
    centroid_P = torch.mean(P, dim=-2, keepdim=True)
    centroid_Q = torch.mean(Q, dim=-2, keepdim=True)
    p = P - centroid_P
    q = Q - centroid_Q
    H = torch.matmul(p.transpose(-2, -1), q)  # (..., D, D) cross-covariance
    # H is square, so the reduced SVD equals the full one.
    U, _, Vt = torch.linalg.svd(H, full_matrices=False)
    V = Vt.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # Reflection correction: if V @ U^T is improper, flip the sign of the last
    # column of V. This is done out of place: writing into the SVD outputs
    # would invalidate the tensors autograd saved for svd's backward.
    d = torch.det(torch.matmul(V, Ut))  # (...,), +1 or -1
    ones = torch.ones_like(d)
    flip = torch.where(d < 0.0, -ones, ones)  # keeps dtype/device of the input
    keep = torch.ones_like(V[..., 0, :-1])  # (..., D-1)
    signs = torch.cat([keep, flip.unsqueeze(-1)], dim=-1)  # (..., D)
    # V * signs[..., None, :] == V @ diag(signs): scales each column of V.
    R = torch.matmul(V * signs.unsqueeze(-2), Ut)
    t = centroid_Q - torch.matmul(centroid_P, R.transpose(-2, -1))
    return R, t

# safe for gradient computation
def kabsch_safe(P: torch.Tensor, T: torch.Tensor):
    """
    Autograd-friendly Kabsch alignment of point set ``P`` onto ``T``.

    Finds the proper rotation R (det +1) and translation t minimising
    ``sum_i ||R p_i + t - t_i||^2``, without any in-place writes to the SVD
    outputs so gradients can flow through R and t.

    P, T: (..., N, 3)
    Returns R (...,3,3), t (...,1,3); the aligned source is
    ``P @ R.transpose(-2, -1) + t``.
    """
    p_centroid = P.mean(dim=-2, keepdim=True)
    t_centroid = T.mean(dim=-2, keepdim=True)

    P0 = P - p_centroid
    T0 = T - t_centroid

    H = P0.transpose(-2, -1) @ T0                   # (...,3,3)
    U, _, Vh = torch.linalg.svd(H, full_matrices=False)
    V = Vh.transpose(-2, -1)                        # out-of-place
    Ut = U.transpose(-2, -1)

    # Reflection correction WITHOUT any in-place on U/V/Vh.
    # The sign tensor is built from ones_like(det) so it inherits the input's
    # dtype and device; torch.where with two Python scalars would produce a
    # default-dtype (float32) tensor and break the matmul for float64 inputs.
    det = torch.det(V @ Ut)                         # (...,)
    ones = torch.ones_like(det)
    s = torch.where(det < 0, -ones, ones)           # (...,)
    F = torch.diag_embed(torch.stack([ones, ones, s], dim=-1))  # (...,3,3)

    R = V @ F @ Ut                                   # (...,3,3)
    t = t_centroid - p_centroid @ R.transpose(-2, -1)  # (...,1,3)
    return R, t

def compute_dihedral(positions):
    """Return the dihedral (torsion) angle, in radians, of four points per batch entry.

    Args:
        positions: tensor indexed as ``positions[:, k]`` for k = 0..3 with
            coordinates on the last axis; presumably shape (B, 4, 3) — TODO
            confirm with callers (the cross product requires 3-vectors).

    Returns:
        Tensor of shape (B,) with angles in [-pi, pi] (the ``torch.atan2``
        range). Collinear points make the normalisations divide by zero and
        yield NaN.
    """
    bonds = positions[:, :-1] - positions[:, 1:]
    first = -bonds[:, 0]
    axis = bonds[:, 1]
    last = bonds[:, 2]

    # Remove from the outer bonds their component along the central bond, so
    # both lie in the plane perpendicular to it.
    axis_sq = torch.sum(axis * axis, dim=-1, keepdim=True)
    first = first - (torch.sum(first * axis, dim=-1, keepdim=True) / axis_sq) * axis
    last = last - (torch.sum(last * axis, dim=-1, keepdim=True) / axis_sq) * axis

    first = first / torch.norm(first, dim=-1, keepdim=True)
    last = last / torch.norm(last, dim=-1, keepdim=True)
    axis = axis / torch.norm(axis, dim=-1, keepdim=True)

    cos_part = torch.sum(first * last, dim=-1)
    sin_part = torch.sum(torch.cross(first, axis, dim=-1) * last, dim=-1)
    return torch.atan2(sin_part, cos_part)


def aldp_diff(position, target_position):
    """Wrapped absolute psi/phi differences between two sets of conformations.

    psi is measured on atoms (6, 8, 14, 16) and phi on atoms (1, 6, 8, 14).
    NOTE(review): these fixed indices assume a specific alanine-dipeptide
    atom ordering — confirm against the topology in use.

    Args:
        position: coordinates indexed as ``position[:, atom]``; presumably
            (B, n_atoms, 3) — TODO confirm with callers.
        target_position: reference coordinates with the same layout.

    Returns:
        (psi_diff, phi_diff): absolute angular differences in radians, each
        folded onto the shorter way around the circle, i.e. in [0, pi].
    """
    device = position.device
    psi_atoms = torch.tensor([6, 8, 14, 16], dtype=torch.long, device=device)
    phi_atoms = torch.tensor([1, 6, 8, 14], dtype=torch.long, device=device)
    period = 2 * torch.pi

    def wrapped_gap(angle, reference):
        # |angle - reference| reduced modulo 2*pi, then the shorter arc.
        gap = torch.abs(angle - reference) % period
        return torch.min(gap, period - gap)

    psi_diff = wrapped_gap(
        compute_dihedral(position[:, psi_atoms]),
        compute_dihedral(target_position[:, psi_atoms]),
    )
    phi_diff = wrapped_gap(
        compute_dihedral(position[:, phi_atoms]),
        compute_dihedral(target_position[:, phi_atoms]),
    )
    return psi_diff, phi_diff


def tic_diff(molecule, position, target_position):
    """Absolute differences in the first two TICA coordinates of two frame sets.

    Loads ``./data/{molecule}/tica_model.pkl`` and featurises frames with
    backbone torsions (cos/sin) on the ``./data/{molecule}/folded.pdb``
    topology, then projects both ``position`` and ``target_position``.

    Args:
        molecule: name of the sub-directory under ``./data``.
        position: coordinates convertible to an mdtraj ``xyz`` array; assumes
            (n_frames, n_atoms, 3) in mdtraj units (nm) — TODO confirm.
        target_position: reference coordinates with the same layout.

    Returns:
        (tic1_diff, tic2_diff): ``|tic_k(position) - tic_k(target)|`` for the
        first and second TICA components, as tensors on ``position.device``.
    """
    data_dir = f"./data/{molecule}"
    pdb_path = f"{data_dir}/folded.pdb"
    # NOTE: joblib.load unpickles the file; only use trusted, locally produced models.
    tica_model = joblib.load(f"{data_dir}/tica_model.pkl")
    feat = coor.featurizer(pdb_path)
    feat.add_backbone_torsions(cossin=True)
    # Parse the topology once and share it between both trajectories.
    topology = md.load(pdb_path).topology

    def project(coords):
        # detach(): .numpy() refuses tensors that require grad, and no
        # gradient can flow through the numpy/mdtraj round-trip anyway.
        traj = md.Trajectory(coords.detach().cpu().numpy(), topology)
        tica = tica_model.transform(feat.transform(traj))
        return torch.from_numpy(tica).to(position.device)

    tica_target = project(target_position)
    tica = project(position)
    tic1_diff = torch.abs(tica[:, 0] - tica_target[:, 0])
    tic2_diff = torch.abs(tica[:, 1] - tica_target[:, 1])
    return tic1_diff, tic2_diff