|
|
import torch |
|
|
from rfdiffusion.potentials import potentials as potentials |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def make_contact_matrix(nchain, intra_all=False, inter_all=False, contact_string=None):
    """
    Calculate a matrix of inter/intra chain contact indicators.

    The result is a float array of shape (nchain, nchain) whose entries are
    1 (attractive), -1 (repulsive) or 0 (nothing requested for that pair).
    The three options are applied in the order intra_all, inter_all,
    contact_string, so entries named in contact_string override the
    blanket settings.

    Parameters:
        nchain (int, required): How many chains are in this design

        intra_all (bool, optional): If True, set every diagonal (chain with
            itself) entry to 1

        inter_all (bool, optional): If True, set every off-diagonal (chain
            with a different chain) entry to 1

        contact_string (str, optional): String denoting how to define
            contacts, comma delimited between pairs of chains. Each entry is
            exactly three characters: chain letter, symbol, chain letter
            (e.g. 'A!B,B&C'). '!' denotes repulsive, '&' denotes attractive.
            Chain letters are uppercase 'A'-'Z', mapped to indices 0-25.
            Each entry is written symmetrically to [i, j] and [j, i].

    Returns:
        numpy.ndarray: the (nchain, nchain) contact matrix

    Raises:
        AssertionError: an entry is not three characters long, or its
            symbol is not '!' or '&'
        KeyError: a chain letter is not an uppercase letter 'A'-'Z'
        IndexError: a chain letter maps to an index >= nchain
    """
    letter2num = {letter: idx for idx, letter in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ')}

    contacts = np.zeros((nchain, nchain))

    if intra_all:
        np.fill_diagonal(contacts, 1)

    if inter_all:
        # Everything that is not on the diagonal is an inter-chain pair.
        off_diagonal = ~np.eye(nchain, dtype=bool)
        contacts[off_diagonal] = 1

    if contact_string is not None:
        for c in contact_string.split(','):
            assert len(c) == 3, f'contact entry {c!r} must be exactly 3 characters, e.g. "A!B"'
            i, j = letter2num[c[0]], letter2num[c[2]]

            symbol = c[1]
            assert symbol in ['!', '&'], f'contact symbol {symbol!r} must be "!" or "&"'

            value = -1 if symbol == '!' else 1
            contacts[i, j] = value
            contacts[j, i] = value

    return contacts
|
|
|
|
|
|
|
|
def calc_nchains(symbol, components=1):
    """
    Calculates number of chains for given symmetry.

    The symbol is lower-cased and dispatched on its first letter:
        'c<n>' -> n * components
        'd<n>' -> 2 * n * components
        't...' -> 12 * components
        'o...' -> not implemented

    Parameters:
        symbol (str, required): Symmetry symbol, case-insensitive
            (e.g. 'C3', 'd2', 'tetrahedral')

        components (int, optional): Multiplier applied to the chain count
            of the symmetry

    Returns:
        int: the number of chains

    Raises:
        NotImplementedError: the symbol starts with 'o'
        RuntimeError: the symbol starts with any other unrecognised letter
        ValueError: the text following a leading 'c' or 'd' is not an integer
    """
    S = symbol.lower()

    if S.startswith('c'):
        return int(S[1:]) * components
    elif S.startswith('d'):
        return 2 * int(S[1:]) * components
    elif S.startswith('o'):
        raise NotImplementedError(f'Chain count for symmetry {S} is not implemented')
    elif S.startswith('t'):
        return 12 * components
    else:
        # A single formatted message; passing S as a second positional
        # argument would make the error render as a tuple.
        raise RuntimeError(f'Unknown symmetry symbol {S}')
|
|
|
|
|
|
|
|
class PotentialManager:
    '''
    Class to define a set of potentials from the given config object and to apply all of the specified potentials
    during each cycle of the inference loop.

    Author: NRB
    '''

    def __init__(self,
                 potentials_config,
                 ppi_config,
                 diffuser_config,
                 inference_config,
                 hotspot_0idx,
                 binderlen,
                 ):
        '''
        Parse the configured potential strings and instantiate the potentials.

        Inputs:

            potentials_config: Config object; guide_scale, guide_decay and guiding_potentials are read here,
                               olig_intra_all / olig_inter_all / olig_custom_contact when symmetry is set

            ppi_config: Config object, stored on the instance

            diffuser_config: Config object; only T is read

            inference_config: Config object; symmetry is read when initializing potentials

            hotspot_0idx: Accepted for interface compatibility. NOTE: it is not forwarded to any potential here.

            binderlen (int): When > 0, it is added as the 'binderlen' setting of every potential whose
                             type is listed in potentials.require_binderlen
        '''

        self.potentials_config = potentials_config
        self.ppi_config = ppi_config
        self.inference_config = inference_config

        self.guide_scale = potentials_config.guide_scale
        self.guide_decay = potentials_config.guide_decay

        if potentials_config.guiding_potentials is None:
            setting_list = []
        else:
            setting_list = [self.parse_potential_string(potstr) for potstr in potentials_config.guiding_potentials]

        # Some potentials need the binder length, which is only passed in here
        # rather than written in the potential string, so inject it now.
        if binderlen > 0:
            for setting in setting_list:
                if setting['type'] in potentials.require_binderlen:
                    setting['binderlen'] = binderlen

        self.potentials_to_apply = self.initialize_all_potentials(setting_list)
        self.T = diffuser_config.T

    def is_empty(self):
        '''
        Check whether this instance of PotentialManager actually contains any potentials
        '''

        return len(self.potentials_to_apply) == 0

    def parse_potential_string(self, potstr):
        '''
        Parse a single entry in the list of potentials to be run to a dictionary of settings for that potential.

        An example of how this parsing is done:
            'setting1:val1,setting2:val2,setting3:val3' -> {setting1:val1,setting2:val2,setting3:val3}

        Every value except the one stored under the key 'type' is converted to float.

        Raises:
            IndexError: an entry contains no ':'
            ValueError: a non-'type' value cannot be converted to float
        '''

        # First pass: collect the raw strings (a repeated key keeps its last value).
        raw_settings = {}
        for entry in potstr.split(','):
            fields = entry.split(':')
            raw_settings[fields[0]] = fields[1]

        # Second pass: everything but the potential's name is numeric.
        return {key: (value if key == 'type' else float(value)) for key, value in raw_settings.items()}

    def initialize_all_potentials(self, setting_list):
        '''
        Given a list of potential dictionaries where each dictionary defines the configurations for a single potential,
        initialize all potentials and add to the list of potentials to be applied

        When inference_config.symmetry is set, a chain contact matrix is built from the olig_* options of
        potentials_config and passed to every potential as the 'contact_matrix' keyword argument.
        '''

        to_apply = []

        for potential_dict in setting_list:
            assert(potential_dict['type'] in potentials.implemented_potentials), f'potential with name: {potential_dict["type"]} is not one of the implemented potentials: {potentials.implemented_potentials.keys()}'

            # Every setting except the name is a constructor argument.
            kwargs = {k: v for k, v in potential_dict.items() if k != 'type'}

            if self.inference_config.symmetry:
                num_chains = calc_nchains(symbol=self.inference_config.symmetry, components=1)
                contact_matrix = make_contact_matrix(
                    nchain=num_chains,
                    intra_all=self.potentials_config.olig_intra_all,
                    inter_all=self.potentials_config.olig_inter_all,
                    contact_string=self.potentials_config.olig_custom_contact,
                )
                kwargs['contact_matrix'] = contact_matrix

            to_apply.append(potentials.implemented_potentials[potential_dict['type']](**kwargs))

        return to_apply

    def compute_all_potentials(self, xyz):
        '''
        This is the money call. Evaluate every potential on xyz and return the sum of all of the potentials that are being used

        Each potential's compute(xyz) result is stacked along a new leading dimension and summed over it.
        torch.stack cannot stack an empty list, so check is_empty() before calling this.
        '''

        potential_list = [potential.compute(xyz) for potential in self.potentials_to_apply]
        potential_stack = torch.stack(potential_list, dim=0)

        return torch.sum(potential_stack, dim=0)

    def get_guide_scale(self, t):
        '''
        Given a timestep and a decay type, get the appropriate scale factor to use for applying guiding potentials

        Inputs:

            t (int, required): The current timestep

        Output:

            scale: The scale factor to use for applying guiding potentials. For 'constant' this is
                   guide_scale itself; for 'linear', 'quadratic' and 'cubic' it is
                   (t/T)**p * guide_scale with p = 1, 2, 3.

        Exits the program (SystemExit) if guide_decay is not one of the implemented decay types.
        '''
        # Imported here because the module-level imports do not include sys.
        import sys

        implemented_decay_types = {
            'constant': lambda t: self.guide_scale,
            'linear' : lambda t: t/self.T * self.guide_scale,
            'quadratic' : lambda t: t**2/self.T**2 * self.guide_scale,
            'cubic' : lambda t: t**3/self.T**3 * self.guide_scale
        }

        if self.guide_decay not in implemented_decay_types:
            sys.exit(f'decay_type must be one of {implemented_decay_types.keys()}. Received decay_type={self.guide_decay}. Exiting.')

        return implemented_decay_types[self.guide_decay](t)
|
|
|
|
|
|
|
|
|
|
|
|