|
|
import torch |
|
|
import joblib |
|
|
import mdtraj as md |
|
|
import pyemma.coordinates as coor |
|
|
|
|
|
|
|
|
def pairwise_dist(x): |
|
|
dist_matrix = torch.cdist(x, x) |
|
|
return dist_matrix |
|
|
|
|
|
|
|
|
def kabsch(P, Q): |
|
|
centroid_P = torch.mean(P, dim=-2, keepdims=True) |
|
|
centroid_Q = torch.mean(Q, dim=-2, keepdims=True) |
|
|
p = P - centroid_P |
|
|
q = Q - centroid_Q |
|
|
H = torch.matmul(p.transpose(-2, -1), q) |
|
|
U, S, Vt = torch.linalg.svd(H) |
|
|
d = torch.det(torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1))) |
|
|
Vt[d < 0.0, -1] *= -1.0 |
|
|
R = torch.matmul(Vt.transpose(-2, -1), U.transpose(-2, -1)) |
|
|
t = centroid_Q - torch.matmul(centroid_P, R.transpose(-2, -1)) |
|
|
return R, t |
|
|
|
|
|
|
|
|
def kabsch_safe(P: torch.Tensor, T: torch.Tensor): |
|
|
""" |
|
|
P, T: (..., N, 3) |
|
|
Returns R (...,3,3), t (...,1,3) |
|
|
""" |
|
|
p_centroid = P.mean(dim=-2, keepdim=True) |
|
|
t_centroid = T.mean(dim=-2, keepdim=True) |
|
|
|
|
|
P0 = P - p_centroid |
|
|
T0 = T - t_centroid |
|
|
|
|
|
H = P0.transpose(-2, -1) @ T0 |
|
|
U, S, Vh = torch.linalg.svd(H, full_matrices=False) |
|
|
V = Vh.transpose(-2, -1) |
|
|
Ut = U.transpose(-2, -1) |
|
|
|
|
|
|
|
|
det = torch.det(V @ Ut) |
|
|
s = torch.where(det < 0, -1.0, 1.0) |
|
|
ones = torch.ones_like(s) |
|
|
F = torch.diag_embed(torch.stack([ones, ones, s], dim=-1)) |
|
|
|
|
|
R = V @ F @ Ut |
|
|
t = t_centroid - p_centroid @ R.transpose(-2, -1) |
|
|
return R, t |
|
|
|
|
|
def compute_dihedral(positions): |
|
|
v = positions[:, :-1] - positions[:, 1:] |
|
|
v0 = -v[:, 0] |
|
|
v1 = v[:, 2] |
|
|
v2 = v[:, 1] |
|
|
s0 = torch.sum(v0 * v2, dim=-1, keepdim=True) / torch.sum( |
|
|
v2 * v2, dim=-1, keepdim=True |
|
|
) |
|
|
s1 = torch.sum(v1 * v2, dim=-1, keepdim=True) / torch.sum( |
|
|
v2 * v2, dim=-1, keepdim=True |
|
|
) |
|
|
v0 = v0 - s0 * v2 |
|
|
v1 = v1 - s1 * v2 |
|
|
v0 = v0 / torch.norm(v0, dim=-1, keepdim=True) |
|
|
v1 = v1 / torch.norm(v1, dim=-1, keepdim=True) |
|
|
v2 = v2 / torch.norm(v2, dim=-1, keepdim=True) |
|
|
x = torch.sum(v0 * v1, dim=-1) |
|
|
v3 = torch.cross(v0, v2, dim=-1) |
|
|
y = torch.sum(v3 * v1, dim=-1) |
|
|
return torch.atan2(y, x) |
|
|
|
|
|
|
|
|
def aldp_diff(position, target_position): |
|
|
angle_2 = torch.tensor([1, 6, 8, 14], dtype=torch.long, device=position.device) |
|
|
angle_1 = torch.tensor([6, 8, 14, 16], dtype=torch.long, device=position.device) |
|
|
target_psi = compute_dihedral(target_position[:, angle_1]) |
|
|
target_phi = compute_dihedral(target_position[:, angle_2]) |
|
|
psi = compute_dihedral(position[:, angle_1]) |
|
|
phi = compute_dihedral(position[:, angle_2]) |
|
|
psi_diff = torch.abs(psi - target_psi) % (2 * torch.pi) |
|
|
psi_diff = torch.min(psi_diff, 2 * torch.pi - psi_diff) |
|
|
phi_diff = torch.abs(phi - target_phi) % (2 * torch.pi) |
|
|
phi_diff = torch.min(phi_diff, 2 * torch.pi - phi_diff) |
|
|
return psi_diff, phi_diff |
|
|
|
|
|
|
|
|
def tic_diff(molecule, position, target_position): |
|
|
tica_model = joblib.load(f"./data/{molecule}/tica_model.pkl") |
|
|
feat = coor.featurizer(f"./data/{molecule}/folded.pdb") |
|
|
feat.add_backbone_torsions(cossin=True) |
|
|
traj = md.Trajectory( |
|
|
target_position.cpu().numpy(), |
|
|
md.load(f"./data/{molecule}/folded.pdb").topology, |
|
|
) |
|
|
feature = feat.transform(traj) |
|
|
tica_target = tica_model.transform(feature) |
|
|
tica_target = torch.from_numpy(tica_target).to(position.device) |
|
|
traj = md.Trajectory( |
|
|
position.cpu().numpy(), |
|
|
md.load(f"./data/{molecule}/folded.pdb").topology, |
|
|
) |
|
|
feature = feat.transform(traj) |
|
|
tica = tica_model.transform(feature) |
|
|
tica = torch.from_numpy(tica).to(position.device) |
|
|
tic1_diff = abs(tica[:, 0] - tica_target[:, 0]) |
|
|
tic2_diff = abs(tica[:, 1] - tica_target[:, 1]) |
|
|
return tic1_diff, tic2_diff |
|
|
|