|
|
import torch |
|
|
import numpy as np |
|
|
from rfdiffusion.util import generate_Cbeta |
|
|
|
|
|
class Potential:
    '''
    Interface class that defines the functions a potential must implement
    '''

    def compute(self, xyz):
        '''
        Evaluate the potential on the current model prediction.

        Args:
            xyz (torch.tensor, size: [L,27,3]): The current coordinates of the sample

        Returns:
            potential (torch.tensor, size: [1]): A potential whose value will be MAXIMIZED
                by taking a step along it's gradient
        '''
        raise NotImplementedError('Potential compute function was not overwritten')
|
|
|
|
|
class monomer_ROG(Potential):
    '''
    Radius of Gyration potential for encouraging monomer compactness

    Written by DJ and refactored into a class by NRB
    '''

    def __init__(self, weight=1, min_dist=15):
        # Scaling applied to the (negated) radius of gyration
        self.weight = weight
        # Centroid distances below this floor are clamped up to it
        self.min_dist = min_dist

    def compute(self, xyz):
        # Alpha-carbon coordinates of every residue
        ca_coords = xyz[:, 1]

        # Center of the Ca trace
        center = torch.mean(ca_coords, dim=0, keepdim=True)

        # Distance of each Ca to the centroid, clamped from below by min_dist
        dists = torch.cdist(ca_coords[None, ...].contiguous(), center[None, ...].contiguous(), p=2).squeeze(0)
        dists = torch.maximum(self.min_dist * torch.ones_like(dists), dists)

        # Root-mean-square centroid distance = radius of gyration
        rog = torch.sqrt(torch.sum(torch.square(dists)) / ca_coords.shape[0])

        # Negated so that maximizing the potential minimizes the RoG
        return -1 * self.weight * rog
|
|
|
|
|
class binder_ROG(Potential):
    '''
    Radius of Gyration potential for encouraging binder compactness

    Author: NRB
    '''

    def __init__(self, binderlen, weight=1, min_dist=15):
        # Number of binder residues; the binder occupies positions [0, binderlen)
        self.binderlen = binderlen
        # Centroid distances below this floor are clamped up to it
        self.min_dist = min_dist
        # Scaling applied to the (negated) radius of gyration
        self.weight = weight

    def compute(self, xyz):
        # Ca coordinates of the binder residues only
        binder_ca = xyz[:self.binderlen, 1]

        # Center of the binder Ca trace
        center = torch.mean(binder_ca, dim=0, keepdim=True)

        # Per-residue distance to the centroid, clamped from below by min_dist
        dists = torch.cdist(binder_ca[None, ...].contiguous(), center[None, ...].contiguous(), p=2).squeeze(0)
        dists = torch.maximum(self.min_dist * torch.ones_like(dists), dists)

        # Root-mean-square centroid distance = radius of gyration
        rog = torch.sqrt(torch.sum(torch.square(dists)) / binder_ca.shape[0])

        # Negated: maximizing the potential makes the binder more compact
        return -1 * self.weight * rog
|
|
|
|
|
|
|
|
class dimer_ROG(Potential):
    '''
    Radius of Gyration potential for encouraging compactness of both monomers when designing dimers

    Author: PV
    '''

    def __init__(self, binderlen, weight=1, min_dist=15):
        '''
        Args:
            binderlen (int): Number of residues in monomer 1; residues
                [0, binderlen) are monomer 1 and [binderlen, L) are monomer 2
            weight (int/float, optional): Scaling of the (negated) mean RoG
            min_dist (int/float, optional): Floor applied to centroid distances
        '''
        self.binderlen = binderlen
        self.min_dist = min_dist
        self.weight = weight

    def _rog(self, Ca):
        '''Clamped radius of gyration of one monomer's Ca trace.'''
        centroid = torch.mean(Ca, dim=0, keepdim=True)
        dgram = torch.cdist(Ca[None, ...].contiguous(), centroid[None, ...].contiguous(), p=2)
        dgram = torch.maximum(self.min_dist * torch.ones_like(dgram.squeeze(0)), dgram.squeeze(0))
        return torch.sqrt(torch.sum(torch.square(dgram)) / Ca.shape[0])

    def compute(self, xyz):
        # Ca coordinates of each monomer
        Ca_m1 = xyz[:self.binderlen, 1]
        Ca_m2 = xyz[self.binderlen:, 1]

        # BUG FIX: previously monomer 2's centroid was computed from Ca_m1,
        # so monomer 2's RoG was measured against monomer 1's center. Each
        # monomer is now measured against its own centroid.
        rad_of_gyration_m1 = self._rog(Ca_m1)
        rad_of_gyration_m2 = self._rog(Ca_m2)

        # Negated mean of the two radii: maximizing compacts both monomers
        return -1 * self.weight * (rad_of_gyration_m1 + rad_of_gyration_m2) / 2
|
|
|
|
|
class binder_ncontacts(Potential):
    '''
    Differentiable way to maximise number of contacts within a protein

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
    '''

    def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
        # Number of binder residues (binder occupies the first binderlen positions)
        self.binderlen = binderlen
        # r_0 / d_0 parameterize the smooth switching function below
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0

    def compute(self, xyz):
        # Ca atoms of the binder only
        binder_ca = xyz[:self.binderlen, 1]

        # Pairwise Ca-Ca distances within the binder
        pair_dists = torch.cdist(binder_ca[None, ...].contiguous(), binder_ca[None, ...].contiguous(), p=2)

        # PLUMED-style smooth contact count: (1 - x^6) / (1 - x^12),
        # where x = (d - d_0) / r_0
        scaled = (pair_dists - self.d_0) / self.r_0
        contacts = (1 - torch.pow(scaled, 6)) / (1 - torch.pow(scaled, 12))

        total = contacts.sum()
        print("BINDER CONTACTS:", total)

        return self.weight * total
|
|
|
|
|
class interface_ncontacts(Potential):
    '''
    Differentiable way to maximise number of contacts between binder and target

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html

    Author: PV
    '''

    def __init__(self, binderlen, weight=1, r_0=8, d_0=6):
        # Binder occupies residues [0, binderlen), target the remainder
        self.binderlen = binderlen
        # r_0 / d_0 parameterize the smooth switching function below
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0

    def compute(self, xyz):
        # Split Ca coordinates into binder and target halves
        binder_ca = xyz[:self.binderlen, 1]
        target_ca = xyz[self.binderlen:, 1]

        # Cross distances between every binder Ca and every target Ca
        cross_dists = torch.cdist(binder_ca[None, ...].contiguous(), target_ca[None, ...].contiguous(), p=2)

        # PLUMED-style smooth contact count: (1 - x^6) / (1 - x^12),
        # where x = (d - d_0) / r_0
        scaled = (cross_dists - self.d_0) / self.r_0
        contacts = (1 - torch.pow(scaled, 6)) / (1 - torch.pow(scaled, 12))

        total = contacts.sum()
        print("INTERFACE CONTACTS:", total)

        return self.weight * total
|
|
|
|
|
|
|
|
class monomer_contacts(Potential):
    '''
    Differentiable way to maximise number of contacts within a protein

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
    Author: PV

    NOTE: This function sometimes produces NaN's -- added check in reverse diffusion for nan grads
    '''

    def __init__(self, weight=1, r_0=8, d_0=2, eps=1e-6):
        # r_0 / d_0 parameterize the smooth switching function below
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0
        # NOTE(review): eps is stored but never read in compute(); kept for
        # interface compatibility
        self.eps = eps

    def compute(self, xyz):
        # Ca coordinates of every residue
        ca_coords = xyz[:, 1]

        # Full pairwise Ca-Ca distance matrix (includes the zero diagonal)
        pair_dists = torch.cdist(ca_coords[None, ...].contiguous(), ca_coords[None, ...].contiguous(), p=2)

        # PLUMED-style smooth contact count: (1 - x^6) / (1 - x^12),
        # where x = (d - d_0) / r_0
        scaled = (pair_dists - self.d_0) / self.r_0
        contacts = (1 - torch.pow(scaled, 6)) / (1 - torch.pow(scaled, 12))

        return self.weight * contacts.sum()
|
|
|
|
|
|
|
|
class olig_contacts(Potential):
    """
    Applies PV's num contacts potential within/between chains in symmetric oligomers

    Author: DJ
    """

    def __init__(self,
                 contact_matrix,
                 weight_intra=1,
                 weight_inter=1,
                 r_0=8, d_0=2):
        """
        Parameters:
            contact_matrix (torch.tensor/np.array, required):
                square matrix of shape (Nchains,Nchains) whose (i,j) entry represents
                attractive (1), repulsive (-1), or non-existent (0) contact potentials
                between chains in the complex

            weight_intra (int/float, optional): scaling for within-chain contacts
            weight_inter (int/float, optional): scaling for between-chain contacts
        """
        self.contact_matrix = contact_matrix
        self.weight_intra = weight_intra
        self.weight_inter = weight_inter
        self.r_0 = r_0
        self.d_0 = d_0

        # Only -1 (repel), 0 (ignore), 1 (attract) are meaningful entries
        assert all([entry in [-1, 0, 1] for entry in contact_matrix.flatten()]), 'Contact matrix must contain only 0, 1, or -1 in entries'

        # The matrix must be square and symmetric
        shape = contact_matrix.shape
        assert len(shape) == 2
        assert shape[0] == shape[1]
        for row in range(shape[0]):
            for col in range(shape[1]):
                assert contact_matrix[row, col] == contact_matrix[col, row]
        self.nchain = shape[0]

    def _get_idx(self, i, L):
        """
        Returns the zero-indexed indices of the residues in chain i
        """
        # Chains are assumed to be equal-length contiguous segments
        assert L % self.nchain == 0
        Lchain = L // self.nchain
        return i * Lchain + torch.arange(Lchain)

    def compute(self, xyz):
        """
        Iterate through the contact matrix, compute contact potentials between chains that need it,
        and negate contacts for any
        """
        L = xyz.shape[0]
        total = 0

        # Visit each unordered chain pair (upper triangle, including i == j)
        for i in range(self.nchain):
            for j in range(i, self.nchain):
                if self.contact_matrix[i, j] == 0:
                    continue

                # Ca atoms of the two chains
                Ca_i = xyz[self._get_idx(i, L), 1]
                Ca_j = xyz[self._get_idx(j, L), 1]

                dgram = torch.cdist(Ca_i[None, ...].contiguous(), Ca_j[None, ...].contiguous(), p=2)

                # PLUMED-style smooth contact count: (1 - x^6)/(1 - x^12)
                scaled = (dgram - self.d_0) / self.r_0
                ncontacts = (1 - torch.pow(scaled, 6)) / (1 - torch.pow(scaled, 12))

                # NOTE(review): intra-chain terms use half of weight_intra;
                # preserved as-is from the original implementation
                if i == j:
                    scalar = self.weight_intra / 2
                else:
                    scalar = self.weight_inter

                # Sign of the matrix entry makes the term attractive or repulsive
                total += ncontacts.sum() * self.contact_matrix[i, j] * scalar

        return total
|
|
|
|
|
def get_damped_lj(r_min, r_lin, p1=6, p2=12):
    """
    Build an LJ-like potential that is linearized below r_lin.

    Below r_lin the Lennard-Jones curve is replaced by its tangent line at
    r_lin, which keeps short-range values (and gradients) bounded.
    """
    # Value and slope of the LJ curve at the linearization point
    y0 = lj(r_lin, r_min, p1, p2)
    dy0 = lj_grad(r_lin, r_min, p1, p2)

    def inner(dgram):
        # Tangent-line extension below r_lin, plain LJ at or above it
        linear_part = dy0 * (dgram - r_lin) + y0
        return (dgram < r_lin) * linear_part + (dgram >= r_lin) * lj(dgram, r_min, p1, p2)

    return inner
|
|
|
|
|
def lj(dgram, r_min, p1=6, p2=12):
    """Generalized Lennard-Jones energy with its minimum (value -1) at r_min."""
    # The 2**(1/p1) factor rescales so the p1/p2 potential's minimum sits at r_min
    sigma_term = r_min / (2 ** (1 / p1) * dgram)
    return 4 * (sigma_term ** p2 - sigma_term ** p1)
|
|
|
|
|
def lj_grad(dgram, r_min, p1=6, p2=12):
    """Closed-form derivative of lj() with respect to dgram.

    NOTE(review): this closed form equals d(lj)/d(dgram) only when
    p2 == 2 * p1 (true for the defaults p1=6, p2=12) -- confirm callers
    never pass other exponents.
    """
    return -p2 * r_min ** p1 * (r_min ** p1 - dgram ** p1) / dgram ** (p2 + 1)
|
|
|
|
|
def mask_expand(mask, n=1):
    """Dilate a 1-D boolean mask: set True at every position within n of a True entry."""
    assert mask.ndim == 1
    expanded = mask.clone()
    length = len(mask)
    # Widen a window of +/- n (clipped to the mask bounds) around each True entry
    for center in torch.where(mask)[0].tolist():
        lo = max(center - n, 0)
        hi = min(center + n, length - 1)
        expanded[lo:hi + 1] = True
    return expanded
|
|
|
|
|
def contact_energy(dgram, d_0, r_0):
    """Negated PLUMED-style smooth contact count for a distance tensor."""
    # x = (d - d_0) / r_0; contact weight is (1 - x^6) / (1 - x^12)
    scaled = (dgram - d_0) / r_0
    sixth = torch.pow(scaled, 6)
    twelfth = torch.pow(scaled, 12)
    ncontacts = (1 - sixth) / (1 - twelfth).float()
    # Negated so it acts as an energy (lower is better)
    return -ncontacts
|
|
|
|
|
def poly_repulse(dgram, r, slope, p=1):
    """Polynomial repulsion that is nonzero only for distances below r.

    NOTE(review): `slope` enters the expression twice (inside the coefficient
    and as a trailing factor), so the effective prefactor scales as slope**2;
    preserved as-is.
    """
    coeff = slope / (p * r ** (p - 1))
    return (dgram < r) * coeff * torch.abs(r - dgram) ** p * slope
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class substrate_contacts(Potential):
    '''
    Implicitly models a ligand with an attractive-repulsive potential.

    NOTE(review): compute() reads self.xyz_motif and self.diffusion_mask,
    which are not set in __init__ -- they appear to be attached to this
    object externally before sampling; confirm against the caller.
    '''

    def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1):
        '''
        Args:
            weight (int/float, optional): overall scaling of the (negated) summed energies
            r_0, d_0 (int/float, optional): parameters of the attractive contact term
            s (int/float, optional): scaling of the attractive contact term
            eps (float, optional): stored but not read in compute()
            rep_r_0, rep_s (int/float, optional): range and scaling of the repulsive term
            rep_r_min: if truthy, repulsion acts on the per-residue minimum
                substrate distance; otherwise on the full distance matrix
        '''
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0
        self.eps = eps

        # Motif bookkeeping, populated by _grab_motif_residues at compute time
        self.motif_frame = None             # 4 atoms defining the alignment frame
        self.motif_mapping = None           # list of (residue, atom) pairs for the frame
        self.motif_substrate_atoms = None   # substrate coordinates in the motif frame

        # Attractive term acts on the closest substrate atom per residue;
        # the repulsive term does too unless rep_r_min is falsy.
        self.energies = []
        self.energies.append(lambda dgram: s * contact_energy(torch.min(dgram, dim=-1)[0], d_0, r_0))
        if rep_r_min:
            self.energies.append(lambda dgram: poly_repulse(torch.min(dgram, dim=-1)[0], rep_r_0, rep_s, p=1.5))
        else:
            self.energies.append(lambda dgram: poly_repulse(dgram, rep_r_0, rep_s, p=1.5))

    def compute(self, xyz):
        # Select the 4 motif atoms used to define the alignment frame
        self._grab_motif_residues(self.xyz_motif)

        # Reference quantity used below as an alignment sanity check
        first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.motif_substrate_atoms[0] - self.motif_frame[0]), dim=-1)))

        # Rebuild the frame from the current coordinates of the same (res, atom) pairs
        res = torch.tensor([k[0] for k in self.motif_mapping])
        atoms = torch.tensor([k[1] for k in self.motif_mapping])
        new_frame = xyz[self.diffusion_mask][res, atoms, :]

        # Affine transform mapping the stored motif frame onto the current frame
        A, t = self._recover_affine(self.motif_frame, new_frame)

        # Carry the substrate atoms along with the motif and verify the alignment
        substrate_atoms = torch.mm(A, self.motif_substrate_atoms.transpose(0, 1)).transpose(0, 1) + t
        second_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(new_frame[0] - substrate_atoms[0]), dim=-1)))
        assert abs(first_distance - second_distance) < 0.01, "Alignment seems to be bad"

        # Exclude the motif (plus one flanking residue on each side) from the energy
        diffusion_mask = mask_expand(self.diffusion_mask, 1)
        Ca = xyz[~diffusion_mask, 1]

        # Distances from each diffused Ca to every substrate atom
        dgram = torch.cdist(Ca[None, ...].contiguous(), substrate_atoms.float()[None], p=2)[0]

        all_energies = []
        for i, energy_fn in enumerate(self.energies):
            energy = energy_fn(dgram)
            all_energies.append(energy.sum())
        # Energies are costs, so negate to obtain a reward to maximize.
        # (BUG FIX: removed an unreachable trailing `return self.weight *
        # ncontacts.sum()` that referenced an undefined local.)
        return - self.weight * sum(all_energies)

    def _recover_affine(self, frame1, frame2):
        """
        Uses Simplex Affine Matrix (SAM) formula to recover affine transform between two sets of 4 xyz coordinates
        See: https://www.researchgate.net/publication/332410209_Beginner%27s_guide_to_mapping_simplexes_affinely

        Args:
            frame1 - 4 coordinates from starting frame [4,3]
            frame2 - 4 coordinates from ending frame [4,3]

        Outputs:
            A - affine transformation matrix from frame1->frame2
            t - affine translation vector from frame1->frame2
        """
        l = len(frame1)

        # Homogeneous coordinates of the source simplex
        B = torch.vstack([frame1.T, torch.ones(l)])
        D = 1.0 / torch.linalg.det(B)

        # Build the 3x4 affine matrix one cofactor at a time (SAM formula)
        M = torch.zeros((3, 4), dtype=torch.float64)
        for i, R in enumerate(frame2.T):
            for j in range(l):
                num = torch.vstack([R, B])
                # Delete row j+1 (the j-th row of B) from the stacked matrix
                num = torch.cat((num[:j + 1], num[j + 2:]))
                M[i][j] = (-1) ** j * D * torch.linalg.det(num)

        # Split into linear part and translation
        A, t = torch.hsplit(M, [l - 1])
        t = t.transpose(0, 1)
        return A, t

    def _grab_motif_residues(self, xyz) -> None:
        """
        Grabs 4 atoms in the motif.
        Currently random subset of Ca atoms if the motif is >= 4 residues, or else 4 random atoms from a single residue
        """
        idx = torch.arange(self.diffusion_mask.shape[0])
        idx = idx[self.diffusion_mask].float()
        if torch.sum(self.diffusion_mask) >= 4:
            # Sample 4 distinct motif residues; use their Ca atoms as the frame
            rand_idx = torch.multinomial(idx, 4).long()
            self.motif_frame = xyz[rand_idx, 1]
            self.motif_mapping = [(i, 1) for i in rand_idx]
        else:
            # Too few residues: use the first 4 atoms of one sampled residue
            rand_idx = torch.multinomial(idx, 1).long()
            self.motif_frame = xyz[rand_idx[0], :4]
            self.motif_mapping = [(rand_idx, i) for i in range(4)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry mapping config-level potential names to their implementing classes.
implemented_potentials = { 'monomer_ROG': monomer_ROG,
                           'binder_ROG': binder_ROG,
                           'dimer_ROG': dimer_ROG,
                           'binder_ncontacts': binder_ncontacts,
                           'interface_ncontacts': interface_ncontacts,
                           'monomer_contacts': monomer_contacts,
                           'olig_contacts': olig_contacts,
                           'substrate_contacts': substrate_contacts}

# Names of potentials that must be constructed with a `binderlen` argument.
# NOTE(review): 'binder_distance_ReLU' and 'binder_any_ReLU' appear here but
# have no entry in implemented_potentials above -- confirm whether they are
# registered elsewhere or are leftovers from removed potentials.
require_binderlen = { 'binder_ROG',
                      'binder_distance_ReLU',
                      'binder_any_ReLU',
                      'dimer_ROG',
                      'binder_ncontacts',
                      'interface_ncontacts'}
|
|
|
|
|
|