# GlandVergil's picture
# Upload 686 files
# 3cdaa7d verified
import torch
import numpy as np
from rfdiffusion.util import generate_Cbeta
class Potential:
    """
    Base interface listing the functions every potential has to provide.
    """

    def compute(self, xyz):
        """
        Evaluate the potential on the model's current predicted structure.

        Args:
            xyz (torch.tensor, size: [L,27,3]): The current coordinates of the sample

        Returns:
            potential (torch.tensor, size: [1]): A single-entry tensor whose value will
                be MAXIMIZED by taking a step along its gradient

        Raises:
            NotImplementedError: always, unless a subclass overrides this method.
        """
        message = 'Potential compute function was not overwritten'
        raise NotImplementedError(message)
class monomer_ROG(Potential):
    """
    Radius of Gyration potential for encouraging monomer compactness

    Written by DJ and refactored into a class by NRB
    """

    def __init__(self, weight=1, min_dist=15):
        """
        Args:
            weight: scale factor applied to the returned potential
            min_dist: floor applied to every Ca-to-centroid distance before the
                radius of gyration is computed
        """
        self.min_dist = min_dist
        self.weight = weight

    def compute(self, xyz):
        """
        Return the negated, weighted radius of gyration over all residues.

        Atom index 1 of each residue is used as the Ca atom. Distances to the
        centroid that fall below min_dist are raised to min_dist.
        """
        ca_xyz = xyz[:, 1]  # [L,3] for xyz of size [L,27,3]
        center = ca_xyz.mean(dim=0, keepdim=True)  # [1,3]

        # cdist needs a batch dimension; the result is [1,L,1]
        to_center = torch.cdist(
            ca_xyz.unsqueeze(0).contiguous(),
            center.unsqueeze(0).contiguous(),
            p=2,
        ).squeeze(0)  # [L,1]

        floor = self.min_dist * torch.ones_like(to_center)
        to_center = torch.maximum(floor, to_center)

        n_res = ca_xyz.shape[0]
        rog = torch.sqrt(to_center.square().sum() / n_res)

        # Negated so that maximizing the potential shrinks the radius of gyration
        return -1 * self.weight * rog
class binder_ROG(Potential):
    """
    Radius of Gyration potential for encouraging binder compactness

    Author: NRB
    """

    def __init__(self, binderlen, weight=1, min_dist=15):
        """
        Args:
            binderlen: number of leading residues of xyz treated as the binder
            weight: scale factor applied to the returned potential
            min_dist: floor applied to every Ca-to-centroid distance before the
                radius of gyration is computed
        """
        self.weight = weight
        self.min_dist = min_dist
        self.binderlen = binderlen

    def compute(self, xyz):
        """
        Return the negated, weighted radius of gyration of the binder residues.

        Only the first binderlen residues are considered; atom index 1 of each
        residue is used as the Ca atom.
        """
        binder_ca = xyz[:self.binderlen, 1]  # [Lb,3]
        center = binder_ca.mean(dim=0, keepdim=True)  # [1,3]

        # cdist needs a batch dimension; the result is [1,Lb,1]
        to_center = torch.cdist(
            binder_ca.unsqueeze(0).contiguous(),
            center.unsqueeze(0).contiguous(),
            p=2,
        ).squeeze(0)  # [Lb,1]

        floor = self.min_dist * torch.ones_like(to_center)
        to_center = torch.maximum(floor, to_center)

        n_res = binder_ca.shape[0]
        rog = torch.sqrt(to_center.square().sum() / n_res)

        # Negated so that maximizing the potential shrinks the radius of gyration
        return -1 * self.weight * rog
class dimer_ROG(Potential):
    '''
    Radius of Gyration potential for encouraging compactness of both monomers when designing dimers

    Author: PV
    '''

    def __init__(self, binderlen, weight=1, min_dist=15):
        '''
        Args:
            binderlen: number of leading residues of xyz that form monomer 1;
                the remaining residues form monomer 2
            weight: scale factor applied to the returned potential
            min_dist: floor applied to every Ca-to-centroid distance before the
                radius of gyration is computed
        '''
        self.binderlen = binderlen
        self.min_dist = min_dist
        self.weight = weight

    def _clamped_rog(self, Ca):
        '''
        Radius of gyration of one set of Ca coordinates about ITS OWN centroid,
        with each centroid distance floored at min_dist.
        '''
        centroid = torch.mean(Ca, dim=0, keepdim=True)  # [1,3]

        # cdist needs a batch dimension - NRB
        dgram = torch.cdist(Ca[None,...].contiguous(), centroid[None,...].contiguous(), p=2)  # [1,Lm,1]
        dgram = dgram.squeeze(0)  # [Lm,1]
        dgram = torch.maximum(self.min_dist * torch.ones_like(dgram), dgram)

        return torch.sqrt(torch.sum(torch.square(dgram)) / Ca.shape[0])

    def compute(self, xyz):
        '''
        Return the negated, weighted average of the two monomers' radii of gyration.

        Each monomer is measured about its own centroid. (Previously the centroid of
        monomer 1 was also used for monomer 2, which made the second term depend on
        the separation between the monomers rather than on monomer 2's compactness.)
        '''
        # Only look at monomer 1 residues
        Ca_m1 = xyz[:self.binderlen, 1]

        # Only look at monomer 2 residues
        Ca_m2 = xyz[self.binderlen:, 1]

        rad_of_gyration_m1 = self._clamped_rog(Ca_m1)
        rad_of_gyration_m2 = self._clamped_rog(Ca_m2)

        # Potential value is the average of both radii of gyration (is avg. the best way to do this?)
        return -1 * self.weight * (rad_of_gyration_m1 + rad_of_gyration_m2) / 2
class binder_ncontacts(Potential):
    '''
    Differentiable way to maximise number of contacts within a protein

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
    '''

    def __init__(self, binderlen, weight=1, r_0=8, d_0=4):
        '''
        Args:
            binderlen: number of leading residues of xyz treated as the binder
            weight: scale factor applied to the returned potential
            r_0: distance scale the shifted distances are divided by
            d_0: offset subtracted from every pairwise distance
        '''
        self.binderlen = binderlen
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0

    def compute(self, xyz):
        '''
        Return the weighted sum of the switching function over all binder Ca-Ca pairs.
        '''
        # Only look at binder Ca residues
        Ca = xyz[:self.binderlen, 1]  # [Lb,3]

        # cdist needs a batch dimension - NRB
        dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2)  # [1,Lb,Lb]
        divide_by_r_0 = (dgram - self.d_0) / self.r_0

        # (1 - x^6) / (1 - x^12) equals 1 / (1 + x^6) wherever x^6 != 1. The factored
        # form avoids the 0/0 (NaN) the quotient produces when a distance is exactly
        # d_0 + r_0, and the loss of precision close to that point.
        binder_ncontacts = 1 / (1 + torch.pow(divide_by_r_0, 6))

        print("BINDER CONTACTS:", binder_ncontacts.sum())

        # Potential value is the weighted sum of the pairwise contact scores
        return self.weight * binder_ncontacts.sum()
class interface_ncontacts(Potential):
    '''
    Differentiable way to maximise number of contacts between binder and target

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html

    Author: PV
    '''

    def __init__(self, binderlen, weight=1, r_0=8, d_0=6):
        '''
        Args:
            binderlen: number of leading residues of xyz treated as the binder;
                the remaining residues are treated as the target
            weight: scale factor applied to the returned potential
            r_0: distance scale the shifted distances are divided by
            d_0: offset subtracted from every pairwise distance
        '''
        self.binderlen = binderlen
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0

    def compute(self, xyz):
        '''
        Return the weighted sum of the switching function over all binder-target Ca pairs.
        '''
        # Extract binder Ca residues
        Ca_b = xyz[:self.binderlen, 1]  # [Lb,3]

        # Extract target Ca residues
        Ca_t = xyz[self.binderlen:, 1]  # [Lt,3]

        # cdist needs a batch dimension - NRB
        dgram = torch.cdist(Ca_b[None,...].contiguous(), Ca_t[None,...].contiguous(), p=2)  # [1,Lb,Lt]
        divide_by_r_0 = (dgram - self.d_0) / self.r_0

        # (1 - x^6) / (1 - x^12) equals 1 / (1 + x^6) wherever x^6 != 1. The factored
        # form avoids the 0/0 (NaN) the quotient produces when a distance is exactly
        # d_0 + r_0, and the loss of precision close to that point.
        pair_contacts = 1 / (1 + torch.pow(divide_by_r_0, 6))

        # Potential is the sum of values in the tensor
        interface_ncontacts = pair_contacts.sum()

        print("INTERFACE CONTACTS:", interface_ncontacts)

        return self.weight * interface_ncontacts
class monomer_contacts(Potential):
    '''
    Differentiable way to maximise number of contacts within a protein

    Motivation is given here: https://www.plumed.org/doc-v2.7/user-doc/html/_c_o_o_r_d_i_n_a_t_i_o_n.html
    Author: PV

    NOTE: The switching function is evaluated in its factored form, which removes the
    0/0 at distance == d_0 + r_0. Other NaN sources (if any) are not addressed here, so
    the NaN-gradient check in reverse diffusion mentioned by the original author should
    stay in place.
    '''

    def __init__(self, weight=1, r_0=8, d_0=2, eps=1e-6):
        '''
        Args:
            weight: scale factor applied to the returned potential
            r_0: distance scale the shifted distances are divided by
            d_0: offset subtracted from every pairwise distance
            eps: stored on the instance; not used by compute()
        '''
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0
        self.eps = eps

    def compute(self, xyz):
        '''
        Return the weighted sum of the switching function over all Ca-Ca pairs.
        '''
        Ca = xyz[:, 1]  # [L,3]

        # cdist needs a batch dimension - NRB
        dgram = torch.cdist(Ca[None,...].contiguous(), Ca[None,...].contiguous(), p=2)  # [1,L,L]
        divide_by_r_0 = (dgram - self.d_0) / self.r_0

        # (1 - x^6) / (1 - x^12) equals 1 / (1 + x^6) wherever x^6 != 1. The factored
        # form avoids the 0/0 (NaN) the quotient produces when a distance is exactly
        # d_0 + r_0, and the loss of precision close to that point.
        ncontacts = 1 / (1 + torch.pow(divide_by_r_0, 6))

        # Potential value is the weighted sum of the pairwise contact scores
        return self.weight * ncontacts.sum()
class olig_contacts(Potential):
    """
    Applies PV's num contacts potential within/between chains in symmetric oligomers

    Author: DJ
    """

    def __init__(self,
                 contact_matrix,
                 weight_intra=1,
                 weight_inter=1,
                 r_0=8, d_0=2):
        """
        Parameters:
            contact_matrix (torch.tensor/np.array, required):
                square, symmetric matrix of shape (Nchains,Nchains) whose (i,j) entry
                represents attractive (1), repulsive (-1), or non-existent (0) contact
                potentials between chains in the complex

            weight_intra (int/float, optional): Scaling factor for contacts within a
                chain (i == j); it is halved when applied so that pairs are not
                double counted

            weight_inter (int/float, optional): Scaling factor for contacts between
                different chains (i != j)

            r_0 (int/float, optional): distance scale the shifted distances are divided by

            d_0 (int/float, optional): offset subtracted from every pairwise distance
        """
        self.contact_matrix = contact_matrix
        self.weight_intra = weight_intra
        self.weight_inter = weight_inter
        self.r_0 = r_0
        self.d_0 = d_0

        # check contact matrix only contains valid entries
        assert all(i in [-1, 0, 1] for i in contact_matrix.flatten()), 'Contact matrix must contain only 0, 1, or -1 in entries'

        # assert the matrix is square and symmetric
        shape = contact_matrix.shape
        assert len(shape) == 2
        assert shape[0] == shape[1]
        for i in range(shape[0]):
            for j in range(shape[1]):
                assert contact_matrix[i, j] == contact_matrix[j, i]

        self.nchain = shape[0]

    def _get_idx(self, i, L):
        """
        Returns the zero-indexed indices of the residues in chain i

        All chains are taken to have the same length, L // nchain, so L must be
        divisible by the number of chains.
        """
        assert L % self.nchain == 0
        Lchain = L // self.nchain
        return i * Lchain + torch.arange(Lchain)

    def compute(self, xyz):
        """
        Iterate through the contact matrix, compute contact potentials between chains that need it,
        and negate contacts for any pair whose contact matrix entry is -1 (repulsive).

        Returns the accumulated value; this is the plain integer 0 when every visited
        contact matrix entry is zero.
        """
        L = xyz.shape[0]

        all_contacts = 0
        for i in range(self.nchain):
            # only compute for upper triangle (j >= i)
            for j in range(i, self.nchain):
                # disregard zeros in contact matrix
                if self.contact_matrix[i, j] == 0:
                    continue

                # get the indices for these two chains
                idx_i = self._get_idx(i, L)
                idx_j = self._get_idx(j, L)

                Ca_i = xyz[idx_i, 1]  # slice out crds for this chain
                Ca_j = xyz[idx_j, 1]  # slice out crds for that chain
                dgram = torch.cdist(Ca_i[None,...].contiguous(), Ca_j[None,...].contiguous(), p=2)  # [1,Lchain,Lchain]

                divide_by_r_0 = (dgram - self.d_0) / self.r_0

                # (1 - x^6) / (1 - x^12) equals 1 / (1 + x^6) wherever x^6 != 1. The
                # factored form avoids the 0/0 (NaN) the quotient produces when a
                # distance is exactly d_0 + r_0.
                ncontacts = 1 / (1 + torch.pow(divide_by_r_0, 6))

                # weight, don't double count intra
                scalar = (i == j) * self.weight_intra / 2 + (i != j) * self.weight_inter

                # contacts attr/repuls relative weights
                all_contacts += ncontacts.sum() * self.contact_matrix[i, j] * scalar

        return all_contacts
def get_damped_lj(r_min, r_lin, p1=6, p2=12):
    """
    Build a function that evaluates lj() at or beyond r_lin and, below r_lin, the
    straight line that matches lj's value and lj_grad's slope at r_lin.

    Both pieces are evaluated for every input and then selected by multiplying with
    the comparison masks, so a non-finite lj value in the region below r_lin is not
    removed by the masking.
    """
    value_at_switch = lj(r_lin, r_min, p1, p2)
    slope_at_switch = lj_grad(r_lin, r_min, p1, p2)

    def inner(dgram):
        below = dgram < r_lin
        linear_part = slope_at_switch * (dgram - r_lin) + value_at_switch
        lj_part = lj(dgram, r_min, p1, p2)
        return below * linear_part + (dgram >= r_lin) * lj_part

    return inner
def lj(dgram, r_min, p1=6, p2=12):
    """
    Lennard-Jones style term 4 * (q**p2 - q**p1) with q = r_min / (2**(1/p1) * dgram).

    With the default exponents the value at dgram == r_min is -1.
    """
    scaled = r_min / (2 ** (1 / p1) * dgram)
    repulsive = scaled ** p2
    attractive = scaled ** p1
    return 4 * (repulsive - attractive)
def lj_grad(dgram, r_min, p1=6, p2=12):
    """
    Evaluate -p2 * r_min**p1 * (r_min**p1 - dgram**p1) / dgram**(p2 + 1).

    This is the derivative of lj() with respect to dgram when p2 == 2 * p1
    (as with the defaults); it is zero at dgram == r_min.
    """
    r_min_pow = r_min ** p1
    return -p2 * r_min_pow * (r_min_pow - dgram ** p1) / dgram ** (p2 + 1)
def mask_expand(mask, n=1):
    """
    Return a copy of a 1-D mask in which every set position also sets its neighbours
    up to n positions away on either side (clipped to the ends of the mask).

    The input mask is not modified.
    """
    grown = mask.clone()
    assert mask.ndim == 1

    length = len(mask)
    for pos in torch.where(mask)[0]:
        center = int(pos)
        first = max(center - n, 0)
        past_last = min(center + n + 1, length)
        for neighbour in range(first, past_last):
            grown[neighbour] = True
    return grown
def contact_energy(dgram, d_0, r_0):
    """
    Negated contact switching function of the distances in dgram.

    With x = (dgram - d_0) / r_0 the returned value is -1 / (1 + x^6). This is the
    factored form of -(1 - x^6) / (1 - x^12): the two agree wherever x^6 != 1, but the
    quotient evaluates to 0/0 (NaN) when a distance is exactly d_0 + r_0, whereas the
    factored form gives the limiting value -0.5 there.

    Args:
        dgram: tensor of distances
        d_0: offset subtracted from every distance
        r_0: distance scale the shifted distances are divided by

    Returns:
        Tensor shaped like dgram. As before, the denominator is cast to float32, so
        the result is float32 for float32 (or lower precision) input.
    """
    divide_by_r_0 = (dgram - d_0) / r_0
    sixth_power = torch.pow(divide_by_r_0, 6)

    # ones_like(...) / (...).float() keeps the dtype promotion of the original
    # numerator / denominator.float() expression.
    ncontacts = torch.ones_like(sixth_power) / (1 + sixth_power).float()
    return - ncontacts
def poly_repulse(dgram, r, slope, p=1):
    """
    Polynomial penalty that is non-zero only where dgram < r.

    Inside the cutoff the value is a * |r - dgram|**p * slope with
    a = slope / (p * r**(p - 1)); elsewhere it is zero. Note that slope enters
    the result twice (once through a and once as the trailing factor).
    """
    prefactor = slope / (p * r ** (p - 1))
    inside_cutoff = dgram < r
    penetration = torch.abs(r - dgram) ** p
    return inside_cutoff * prefactor * penetration * slope
#def only_top_n(dgram
class substrate_contacts(Potential):
    '''
    Implicitly models a ligand with an attractive-repulsive potential.

    compute() reads self.xyz_motif, self.diffusion_mask and self.motif_substrate_atoms.
    None of these receive usable values in __init__; xyz_motif is assigned by the
    model runner, and the other two are presumably assigned the same way
    (NOTE(review): confirm against the caller).
    '''

    def __init__(self, weight=1, r_0=8, d_0=2, s=1, eps=1e-6, rep_r_0=5, rep_s=2, rep_r_min=1):
        '''
        Args:
            weight: scale factor applied to the returned potential
            r_0, d_0: scale and offset passed to contact_energy (attractive term)
            s: multiplier on the attractive term
            eps: stored on the instance; not used by this class
            rep_r_0, rep_s: cutoff radius and slope passed to poly_repulse (repulsive term)
            rep_r_min: if truthy, the repulsive term only sees each residue's minimum
                distance to the substrate; otherwise it sees every residue-atom distance
        '''
        self.r_0 = r_0
        self.weight = weight
        self.d_0 = d_0
        self.eps = eps

        # motif frame coordinates
        # NOTE: these probably need to be set after sample_init() call, because the motif sequence position in design must be known
        self.motif_frame = None  # [4,3] xyz coordinates from 4 atoms of input motif
        # list of (resi, atom_idx) tuples locating the above atoms in the design;
        # resi counts positions within the residues selected by diffusion_mask
        self.motif_mapping = None
        self.motif_substrate_atoms = None  # xyz coordinates of substrate from input motif

        # Each energy term maps a [residues, substrate atoms] distance matrix to energies
        self.energies = []
        self.energies.append(lambda dgram: s * contact_energy(torch.min(dgram, dim=-1)[0], d_0, r_0))
        if rep_r_min:
            self.energies.append(lambda dgram: poly_repulse(torch.min(dgram, dim=-1)[0], rep_r_0, rep_s, p=1.5))
        else:
            self.energies.append(lambda dgram: poly_repulse(dgram, rep_r_0, rep_s, p=1.5))

    def compute(self, xyz):
        '''
        Place the substrate relative to the current motif atoms and return
        -weight * (sum of all energy terms) over the non-motif Ca atoms.

        Raises:
            AssertionError: if the transformed substrate does not keep its distance
                to the first frame atom (alignment sanity check).
        '''
        # First, get random set of atoms
        # This operates on self.xyz_motif, which is assigned to this class in the model runner (for horrible plumbing reasons)
        self._grab_motif_residues(self.xyz_motif)

        # for checking affine transformation is correct
        # NOTE(review): the square root is applied twice, so this is the square root of
        # the distance. The same expression is used before and after the transform, so
        # the comparison stays consistent; left unchanged so the 0.01 tolerance keeps
        # its current meaning.
        first_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(self.motif_substrate_atoms[0] - self.motif_frame[0]), dim=-1)))

        # grab the coordinates of the corresponding atoms in the new frame using mapping
        res = torch.tensor([k[0] for k in self.motif_mapping])
        atoms = torch.tensor([k[1] for k in self.motif_mapping])
        new_frame = xyz[self.diffusion_mask][res, atoms, :]

        # calculate affine transformation matrix and translation vector b/w new frame and motif frame
        A, t = self._recover_affine(self.motif_frame, new_frame)

        # apply affine transformation to substrate atoms
        substrate_atoms = torch.mm(A, self.motif_substrate_atoms.transpose(0, 1)).transpose(0, 1) + t
        second_distance = torch.sqrt(torch.sqrt(torch.sum(torch.square(new_frame[0] - substrate_atoms[0]), dim=-1)))
        assert abs(first_distance - second_distance) < 0.01, "Alignment seems to be bad"

        # Exclude the motif residues and their immediate sequence neighbours
        diffusion_mask = mask_expand(self.diffusion_mask, 1)
        Ca = xyz[~diffusion_mask, 1]

        # cdist needs a batch dimension - NRB
        dgram = torch.cdist(Ca[None,...].contiguous(), substrate_atoms.float()[None], p=2)[0]  # [n non-motif residues, n substrate atoms]

        total_energy = sum(energy_fn(dgram).sum() for energy_fn in self.energies)
        return - self.weight * total_energy

    def _recover_affine(self, frame1, frame2):
        """
        Uses Simplex Affine Matrix (SAM) formula to recover affine transform between two sets of 4 xyz coordinates
        See: https://www.researchgate.net/publication/332410209_Beginner%27s_guide_to_mapping_simplexes_affinely

        Args:
            frame1 - 4 coordinates from starting frame [4,3]
            frame2 - 4 coordinates from ending frame [4,3]

        Outputs:
            A - affine transformation matrix from frame1->frame2 (float64, [3,3])
            t - affine translation vector from frame1->frame2 (float64, [1,3])
        """
        n_points = len(frame1)

        # construct SAM denominator matrix
        B = torch.vstack([frame1.T, torch.ones(n_points)])
        D = 1.0 / torch.linalg.det(B)  # SAM denominator

        M = torch.zeros((3, 4), dtype=torch.float64)
        for i, R in enumerate(frame2.T):
            # R stacked on top of B; it does not depend on j, so build it once per row
            stacked = torch.vstack([R, B])
            for j in range(n_points):
                # make SAM numerator matrix by dropping row j+1 of the stack
                num = torch.cat((stacked[:j+1], stacked[j+2:]))
                # calculate SAM entry
                M[i][j] = (-1)**j * D * torch.linalg.det(num)

        A, t = torch.hsplit(M, [n_points - 1])
        t = t.transpose(0, 1)
        return A, t

    def _grab_motif_residues(self, xyz) -> None:
        """
        Grabs 4 atoms in the motif and stores them in motif_frame / motif_mapping.

        Uniformly random subset of Ca atoms if the motif is >= 4 residues, or else
        the first 4 atoms of a single randomly chosen residue.

        The positions are drawn from 0..n_motif-1, i.e. they index the residues
        selected by diffusion_mask. Sampling uses equal weights: weighting by the
        residue index (as was done before) gives index 0 zero weight, which makes
        torch.multinomial fail when too few non-zero weights remain.
        """
        n_motif = int(torch.sum(self.diffusion_mask))
        weights = torch.ones(n_motif)

        if n_motif >= 4:
            rand_idx = torch.multinomial(weights, 4).long()
            # get Ca atoms
            self.motif_frame = xyz[rand_idx, 1]
            self.motif_mapping = [(int(i), 1) for i in rand_idx]
        else:
            rand_idx = torch.multinomial(weights, 1).long()
            resi = int(rand_idx[0])
            self.motif_frame = xyz[resi, :4]
            self.motif_mapping = [(resi, i) for i in range(4)]
# Registry of potential classes keyed by the potential's name. Used by PotentialManager.
# A newly implemented potential must be added to this dictionary for the
# PotentialManager to be able to use it.
implemented_potentials = {
    'monomer_ROG': monomer_ROG,
    'binder_ROG': binder_ROG,
    'dimer_ROG': dimer_ROG,
    'binder_ncontacts': binder_ncontacts,
    'interface_ncontacts': interface_ncontacts,
    'monomer_contacts': monomer_contacts,
    'olig_contacts': olig_contacts,
    'substrate_contacts': substrate_contacts,
}

# Names of potentials that presumably need a binderlen argument to be constructed.
# NOTE(review): 'binder_distance_ReLU' and 'binder_any_ReLU' have no entry in the
# registry above - confirm whether they are defined elsewhere or are stale.
require_binderlen = {
    'binder_ROG',
    'binder_distance_ReLU',
    'binder_any_ReLU',
    'dimer_ROG',
    'binder_ncontacts',
    'interface_ncontacts',
}