# Sophia Tang
# Initial commit with LFS
# 7efee70
import torch
import joblib
import mdtraj as md
import pyemma.coordinates as coor
def pairwise_dist(x):
    """Return the matrix of pairwise distances between the rows of ``x``.

    The computation is delegated to ``torch.cdist(x, x)``, i.e. every point
    set is compared against itself using that function's default norm.
    """
    return torch.cdist(x, x)
def kabsch(P, Q):
    """Compute the rigid transform that best superimposes ``P`` onto ``Q``.

    Args:
        P: Source points, shape ``(..., N, D)``.
        Q: Target points, same shape as ``P``.

    Returns:
        ``(R, t)`` with ``R`` of shape ``(..., D, D)`` and ``t`` of shape
        ``(..., 1, D)`` such that ``P @ R.transpose(-2, -1) + t`` approximates
        ``Q``. ``R`` is forced to be a proper rotation (determinant ``+1``).
    """
    centroid_P = torch.mean(P, dim=-2, keepdim=True)
    centroid_Q = torch.mean(Q, dim=-2, keepdim=True)
    p = P - centroid_P
    q = Q - centroid_Q

    # Cross-covariance of the centred point sets.
    H = torch.matmul(p.transpose(-2, -1), q)
    U, _, Vt = torch.linalg.svd(H)
    V = Vt.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # Where V @ U^T would be a reflection (det < 0), negate the last column
    # of V. This is done out-of-place: writing into the SVD outputs in place
    # invalidates the tensors autograd saved for the SVD backward pass.
    d = torch.det(torch.matmul(V, Ut))
    ones = torch.ones_like(d)
    sign = torch.where(d < 0.0, -ones, ones)
    col_scale = torch.cat(
        [torch.ones_like(V[..., 0, :-1]), sign.unsqueeze(-1)], dim=-1
    )

    R = torch.matmul(V * col_scale.unsqueeze(-2), Ut)
    t = centroid_Q - torch.matmul(centroid_P, R.transpose(-2, -1))
    return R, t
# safe for gradient computation
def kabsch_safe(P: torch.Tensor, T: torch.Tensor):
    """Kabsch alignment of ``P`` onto ``T`` without in-place tensor edits.

    Args:
        P: Source points, shape ``(..., N, 3)``.
        T: Target points, same shape as ``P``.

    Returns:
        ``(R, t)`` with ``R`` of shape ``(..., 3, 3)`` and ``t`` of shape
        ``(..., 1, 3)`` such that ``P @ R.transpose(-2, -1) + t`` approximates
        ``T``. ``R`` is forced to be a proper rotation (determinant ``+1``).
    """
    p_centroid = P.mean(dim=-2, keepdim=True)
    t_centroid = T.mean(dim=-2, keepdim=True)
    P0 = P - p_centroid
    T0 = T - t_centroid

    H = P0.transpose(-2, -1) @ T0  # (..., 3, 3) cross-covariance
    U, _, Vh = torch.linalg.svd(H, full_matrices=False)
    V = Vh.transpose(-2, -1)  # out-of-place view
    Ut = U.transpose(-2, -1)

    # Reflection correction WITHOUT any in-place write on U/V/Vh.
    det = torch.det(V @ Ut)  # (...,)
    # Build the +/-1 sign in the dtype/device of the inputs: a sign created
    # from Python float scalars comes out in the default dtype and makes the
    # matmul below fail for float64 (or half) inputs.
    ones = torch.ones_like(det)
    s = torch.where(det < 0, -ones, ones)  # (...,)
    F = torch.diag_embed(torch.stack([ones, ones, s], dim=-1))  # (..., 3, 3)

    R = V @ F @ Ut  # (..., 3, 3)
    t = t_centroid - p_centroid @ R.transpose(-2, -1)  # (..., 1, 3)
    return R, t
def compute_dihedral(positions):
    """Return the signed dihedral angle defined by four consecutive points.

    ``positions`` is indexed as ``positions[:, i]``: dimension 0 is kept as a
    batch dimension, dimension 1 enumerates the points (the first four are
    used) and the last dimension holds the 3 coordinates required by
    ``torch.cross``. The result comes from ``torch.atan2`` and is therefore
    in radians, with the points dimension and coordinate dimension removed.
    """

    def _unit(vec):
        # Scale to unit length along the coordinate axis.
        return vec / torch.norm(vec, dim=-1, keepdim=True)

    def _reject(vec, axis):
        # Remove from ``vec`` its component parallel to ``axis``.
        coeff = torch.sum(vec * axis, dim=-1, keepdim=True) / torch.sum(
            axis * axis, dim=-1, keepdim=True
        )
        return vec - coeff * axis

    # Differences between consecutive points: point[i] - point[i + 1].
    bonds = positions[:, :-1] - positions[:, 1:]
    leading = -bonds[:, 0]
    central = bonds[:, 1]
    trailing = bonds[:, 2]

    # Outer bond vectors projected onto the plane normal to the central bond.
    leading_perp = _unit(_reject(leading, central))
    trailing_perp = _unit(_reject(trailing, central))
    central_hat = _unit(central)

    cos_term = torch.sum(leading_perp * trailing_perp, dim=-1)
    normal = torch.cross(leading_perp, central_hat, dim=-1)
    sin_term = torch.sum(normal * trailing_perp, dim=-1)
    return torch.atan2(sin_term, cos_term)
def aldp_diff(position, target_position):
    """Return the wrapped absolute psi and phi differences to a target.

    Both inputs are indexed as ``x[:, atom_indices]``. Psi is computed by
    ``compute_dihedral`` from atoms ``[6, 8, 14, 16]`` and phi from atoms
    ``[1, 6, 8, 14]``. Each difference is folded into ``[0, pi]`` so that
    angles on either side of the +/-pi seam compare as close.

    NOTE(review): the atom indices presumably match the alanine dipeptide
    topology used elsewhere in the project -- confirm against the data files.

    Returns:
        ``(psi_diff, phi_diff)``.
    """
    full_turn = 2 * torch.pi
    device = position.device
    phi_atoms = torch.tensor([1, 6, 8, 14], dtype=torch.long, device=device)
    psi_atoms = torch.tensor([6, 8, 14, 16], dtype=torch.long, device=device)

    def _circular_gap(atoms):
        # Shortest angular distance between current and target dihedral.
        current = compute_dihedral(position[:, atoms])
        reference = compute_dihedral(target_position[:, atoms])
        gap = torch.abs(current - reference) % full_turn
        return torch.min(gap, full_turn - gap)

    return _circular_gap(psi_atoms), _circular_gap(phi_atoms)
def tic_diff(molecule, position, target_position):
    """Return absolute differences along the first two TICA components.

    Reads ``./data/{molecule}/tica_model.pkl`` and
    ``./data/{molecule}/folded.pdb``, featurizes both coordinate sets with
    backbone torsions (cos/sin encoded), projects them with the loaded model
    and compares columns 0 and 1 of the projections.

    Args:
        molecule: Name of the sub-directory of ``./data`` holding the files.
        position: Coordinates tensor passed to ``md.Trajectory``; its device
            is used for the returned tensors.
        target_position: Reference coordinates, handled the same way.

    Returns:
        ``(tic1_diff, tic2_diff)`` tensors on ``position.device``.
    """
    pdb_path = f"./data/{molecule}/folded.pdb"

    # NOTE: joblib.load unpickles the file; only load model files that come
    # from a trusted source.
    tica_model = joblib.load(f"./data/{molecule}/tica_model.pkl")
    feat = coor.featurizer(pdb_path)
    feat.add_backbone_torsions(cossin=True)
    # Load the topology once and reuse it for both trajectories.
    topology = md.load(pdb_path).topology

    def _project(coords):
        # detach() so tensors that require grad can be converted to NumPy.
        traj = md.Trajectory(coords.detach().cpu().numpy(), topology)
        projected = tica_model.transform(feat.transform(traj))
        return torch.from_numpy(projected).to(position.device)

    tica_target = _project(target_position)
    tica = _project(position)

    tic1_diff = torch.abs(tica[:, 0] - tica_target[:, 0])
    tic2_diff = torch.abs(tica[:, 1] - tica_target[:, 1])
    return tic1_diff, tic2_diff