| from abc import ABC, abstractmethod |
| from typing import Optional, Dict, Any, Set, List, Union |
|
|
| import torch |
| import numpy as np |
| from . import vb_const as const |
| from .vb_potentials_schedules import ( |
| ParameterSchedule, |
| ExponentialInterpolation, |
| PiecewiseStepFunction, |
| ) |
| from .vb_loss_diffusionv2 import weighted_rigid_align |
|
|
|
|
class Potential(ABC):
    """Abstract base for steering potentials evaluated on atom coordinates.

    A concrete potential supplies three hooks:

    * ``compute_args(feats, parameters)`` returns the 5-tuple
      ``(index, args, com_args, ref_args, operator_args)``:

      - ``index``: integer tensor of term indices; when its second dimension
        is empty, ``compute`` / ``compute_gradient`` return zeros.
      - ``args``: extra positional arguments forwarded to ``compute_function``.
      - ``com_args``: ``None`` or ``(com_index, atom_pad_mask)``; when given,
        unpadded atoms are first averaged into one point per ``com_index``.
      - ``ref_args``: ``None`` or
        ``(ref_coords, ref_mask, ref_atom_index, ref_token_index)``.
      - ``operator_args``: ``None`` or ``(negation_mask, union_index)``.

    * ``compute_variable`` maps coordinates to a per-term variable (and,
      optionally, its gradient with respect to the participating points).
    * ``compute_function`` maps that variable to an energy (and, optionally,
      its derivative).
    """

    def __init__(
        self,
        parameters: Optional[
            Dict[str, Union["ParameterSchedule", float, int, bool]]
        ] = None,
    ):
        # Values may be plain constants or ParameterSchedule objects; the
        # latter are evaluated per time step by ``compute_parameters``.
        self.parameters = parameters

    @staticmethod
    def _center_of_mass(coords, com_index, atom_pad_mask):
        """Average unpadded atom coordinates within each ``com_index`` group.

        Returns a tensor of shape
        ``(*batch, com_index[atom_pad_mask].max() + 1, 3)``. A group that
        receives no unpadded atom stays at the origin.
        """
        unpad_com_index = com_index[atom_pad_mask]
        unpad_coords = coords[..., atom_pad_mask, :]
        return torch.zeros(
            (*unpad_coords.shape[:-2], unpad_com_index.max() + 1, 3),
            dtype=unpad_coords.dtype,
            device=coords.device,
        ).scatter_reduce(
            -2,
            unpad_com_index.unsqueeze(-1).expand_as(unpad_coords),
            unpad_coords,
            "mean",
            # Leave the zero-initialised destination out of the average. With
            # the default include_self=True every centre would be
            # sum / (count + 1), i.e. pulled toward the origin.
            include_self=False,
        )

    def _prepare_coords(self, coords, com_args, ref_args):
        """Apply the optional centre-of-mass and reference-atom transforms.

        Returns ``(coords, com_index, ref_coords, ref_mask, ref_token_index)``.
        The last four entries are ``None`` when the matching args are absent.
        """
        com_index = None
        if com_args is not None:
            com_index, atom_pad_mask = com_args
            coords = self._center_of_mass(coords, com_index, atom_pad_mask)

        ref_coords = ref_mask = ref_token_index = None
        if ref_args is not None:
            ref_coords, ref_mask, ref_atom_index, ref_token_index = ref_args
            coords = coords[..., ref_atom_index, :]

        return coords, com_index, ref_coords, ref_mask, ref_token_index

    @staticmethod
    def _union_softmax(energy, union_index, union_lambda):
        """Softmin weights of ``energy``, normalised within each union group.

        The weight of term i is ``exp(-union_lambda * E_i) / Z``, where ``Z``
        is summed over the terms sharing ``union_index[i]``. Terms whose group
        has ``Z == 0`` get weight 0.
        """
        neg_exp_energy = torch.exp(-1 * union_lambda * energy)
        Z = torch.zeros(
            (*energy.shape[:-1], union_index.max() + 1),
            dtype=neg_exp_energy.dtype,
            device=union_index.device,
        ).scatter_reduce(
            -1,
            union_index.expand_as(neg_exp_energy),
            neg_exp_energy,
            "sum",
        )
        Z_per_term = Z[..., union_index]
        softmax_energy = neg_exp_energy / Z_per_term
        softmax_energy[Z_per_term == 0] = 0
        return softmax_energy

    def compute(self, coords, feats, parameters):
        """Return the total energy for each leading (batch) element.

        Args:
            coords: coordinates laid out as ``(..., n_atoms, 3)``.
            feats: feature dict passed through to ``compute_args``.
            parameters: evaluated parameter dict (e.g. from
                ``compute_parameters``). ``"union_lambda"`` is read only when
                ``compute_args`` supplies a union index.

        Returns:
            Zeros of shape ``coords.shape[:-2]`` when there are no terms.
            Otherwise the energy summed over every dimension after the first;
            with a union index, the softmin-weighted energy summed over the
            last dimension.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros(coords.shape[:-2], device=coords.device)

        coords, _, ref_coords, ref_mask, _ = self._prepare_coords(
            coords, com_args, ref_args
        )
        negation_mask, union_index = (
            operator_args if operator_args is not None else (None, None)
        )

        value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=False,
        )
        energy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=False
        )

        if union_index is not None:
            softmax_energy = self._union_softmax(
                energy, union_index, parameters["union_lambda"]
            )
            return (energy * softmax_energy).sum(dim=-1)

        return energy.sum(dim=tuple(range(1, energy.dim())))

    def compute_gradient(self, coords, feats, parameters):
        """Return the energy gradient as one 3-vector per input atom.

        Arguments match ``compute``. Per-term gradients are scatter-summed onto
        the points indexed by ``index``. They are then mapped back to atoms
        through ``com_index`` (centre-of-mass potentials) or
        ``ref_token_index`` (reference potentials) when those are present.
        """
        index, args, com_args, ref_args, operator_args = self.compute_args(
            feats, parameters
        )
        if index.shape[1] == 0:
            return torch.zeros_like(coords)

        coords, com_index, ref_coords, ref_mask, ref_token_index = (
            self._prepare_coords(coords, com_args, ref_args)
        )
        negation_mask, union_index = (
            operator_args if operator_args is not None else (None, None)
        )

        value, grad_value = self.compute_variable(
            coords,
            index,
            ref_coords=ref_coords,
            ref_mask=ref_mask,
            compute_gradient=True,
        )
        energy, dEnergy = self.compute_function(
            value, *args, negation_mask=negation_mask, compute_derivative=True
        )

        if union_index is not None:
            union_lambda = parameters["union_lambda"]
            softmax_energy = self._union_softmax(energy, union_index, union_lambda)
            weighted = energy * softmax_energy
            # Softmin-weighted energy of each union group.
            f = torch.zeros(
                (*energy.shape[:-1], union_index.max() + 1),
                dtype=weighted.dtype,
                device=union_index.device,
            ).scatter_reduce(
                -1,
                union_index.expand_as(energy),
                weighted,
                "sum",
            )
            # NOTE(review): differentiating sum_j E_j * s_j with
            # s_j = exp(-lambda * E_j) / Z analytically gives
            # s_i * (1 - lambda * (E_i - f)). The "+" used here may be an
            # intentional heuristic; confirm before changing it.
            dEnergy = (
                dEnergy
                * softmax_energy
                * (1 + union_lambda * (energy - f[..., union_index]))
            )

        # Each term's derivative multiplies the gradient of its variable for
        # every participating point (grad_value.shape[-3] points per term).
        prod = dEnergy.tile(grad_value.shape[-3]).unsqueeze(
            -1
        ) * grad_value.flatten(start_dim=-3, end_dim=-2)
        if prod.dim() > 3:
            prod = prod.sum(dim=list(range(1, prod.dim() - 2)))
        grad_atom = torch.zeros_like(coords).scatter_reduce(
            -2,
            index.flatten(start_dim=0, end_dim=1)
            .unsqueeze(-1)
            .expand((*coords.shape[:-2], -1, 3)),
            prod,
            "sum",
        )

        if com_index is not None:
            # The gradient w.r.t. a group centre is copied unchanged to every
            # atom of that group (not divided by the group size). Presumably
            # this is an intended step-size choice; confirm before rescaling.
            grad_atom = grad_atom[..., com_index, :]
        elif ref_token_index is not None:
            grad_atom = grad_atom[..., ref_token_index, :]

        return grad_atom

    def compute_parameters(self, t):
        """Evaluate ``self.parameters`` at time ``t``.

        Each ``ParameterSchedule`` entry is replaced by ``entry.compute(t)``;
        other values pass through unchanged. Returns ``None`` when the
        potential was built without parameters.
        """
        if self.parameters is None:
            return None
        return {
            name: parameter.compute(t)
            if isinstance(parameter, ParameterSchedule)
            else parameter
            for name, parameter in self.parameters.items()
        }

    @abstractmethod
    def compute_function(
        self, value, *args, negation_mask=None, compute_derivative=False
    ):
        """Map the variable to an energy; return ``(energy, dEnergy)`` when
        ``compute_derivative`` is true."""
        raise NotImplementedError

    @abstractmethod
    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Compute the per-term variable; return ``(value, grad)`` when
        ``compute_gradient`` is true."""
        raise NotImplementedError

    @abstractmethod
    def compute_args(self, feats, parameters):
        """Return ``(index, args, com_args, ref_args, operator_args)``."""
        raise NotImplementedError

    def get_reference_coords(self, feats, parameters):
        """Default hook returning ``(None, None)``; not used by this class."""
        return None, None
|
|
|
|
class FlatBottomPotential(Potential):
    """Linear penalty that is zero inside ``[lower, upper]`` and grows outside."""

    def compute_function(
        self,
        value,
        k,
        lower_bounds,
        upper_bounds,
        negation_mask=None,
        compute_derivative=False,
    ):
        """Evaluate the flat-bottom energy of ``value``.

        Energy is ``k * (lower - value)`` below the interval,
        ``k * (value - upper)`` above it, and 0 inside. A ``None`` bound is
        treated as infinite.

        When ``negation_mask`` is given, every entry where it is False must
        have at least one infinite bound (asserted). Its one-sided interval is
        replaced by the complementary one-sided interval at the same finite
        bound.

        Returns ``energy``, or ``(energy, dEnergy)`` when
        ``compute_derivative`` is true. ``dEnergy`` is ``-k`` below the
        interval, ``+k`` above it, and 0 inside.
        """
        lo = (
            torch.full_like(value, float("-inf"))
            if lower_bounds is None
            else lower_bounds
        ).expand_as(value).clone()
        hi = (
            torch.full_like(value, float("inf"))
            if upper_bounds is None
            else upper_bounds
        ).expand_as(value).clone()

        if negation_mask is not None:
            open_below = torch.isneginf(lo)
            open_above = torch.isposinf(hi)
            assert torch.all(open_below + open_above + negation_mask)
            # Entries with a finite upper bound become bounded below by it.
            flip = ~open_above * ~negation_mask
            lo[flip] = hi[flip]
            hi[flip] = float("inf")
            # Entries with a finite lower bound become bounded above by it.
            flip = ~open_below * ~negation_mask
            hi[flip] = lo[flip]
            lo[flip] = float("-inf")

        below = value < lo
        above = value > hi

        energy = torch.zeros_like(value)
        energy[below] = (k * (lo - value))[below]
        energy[above] = (k * (value - hi))[above]
        if not compute_derivative:
            return energy

        slope = k.expand_as(below)
        dEnergy = torch.zeros_like(value)
        dEnergy[below] = -1 * slope[below]
        dEnergy[above] = 1 * k.expand_as(above)[above]
        return energy, dEnergy
|
|
|
|
class ReferencePotential(Potential):
    """Potential on per-position distance to rigidly aligned reference coordinates."""

    def compute_variable(
        self, coords, index, ref_coords, ref_mask, compute_gradient=False
    ):
        """Return the distance between ``coords[:, index]`` and the aligned reference.

        ``ref_coords`` is aligned onto ``coords[:, index]`` with
        ``weighted_rigid_align``, using ``ref_mask`` as both weights and mask.
        When ``compute_gradient`` is true, also returns the unit residual
        vector scaled by ``ref_mask``, with a singleton dim inserted at
        position 1. A zero-length residual gets a zero vector instead of the
        NaN that 0/0 would produce.
        """
        aligned_ref_coords = weighted_rigid_align(
            ref_coords.float(),
            coords[:, index].float(),
            ref_mask,
            ref_mask,
        )

        r = coords[:, index] - aligned_ref_coords
        r_norm = torch.linalg.norm(r, dim=-1)

        if not compute_gradient:
            return r_norm

        # Guard the division. Multiplying NaN by a zero mask below would
        # still give NaN, so the mask alone cannot hide it.
        safe_norm = torch.where(r_norm > 0, r_norm, torch.ones_like(r_norm))
        r_hat = r / safe_norm.unsqueeze(-1)
        grad = (r_hat * ref_mask.unsqueeze(-1)).unsqueeze(1)
        return r_norm, grad
|
|
|
|
class DistancePotential(Potential):
    """Potential whose per-term variable is the distance between two atoms."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|x_i - x_j|`` for the atom pairs in ``index``.

        ``index[0]`` and ``index[1]`` select the two atoms of each pair.
        ``ref_coords`` and ``ref_mask`` are accepted for interface
        compatibility and ignored.

        When ``compute_gradient`` is true, also returns the gradients with
        respect to x_i and x_j, stacked along dim 1. A coincident pair gets a
        zero gradient instead of the NaN that 0/0 would produce.
        """
        r_ij = coords.index_select(-2, index[0]) - coords.index_select(-2, index[1])
        r_ij_norm = torch.linalg.norm(r_ij, dim=-1)

        if not compute_gradient:
            return r_ij_norm

        # The unit vector is only needed for the gradient. Guard the division
        # so a NaN cannot leak into the scatter-summed atom gradients.
        safe_norm = torch.where(r_ij_norm > 0, r_ij_norm, torch.ones_like(r_ij_norm))
        r_hat_ij = r_ij / safe_norm.unsqueeze(-1)

        grad_i = r_hat_ij
        grad_j = -1 * r_hat_ij
        grad = torch.stack((grad_i, grad_j), dim=1)
        return r_ij_norm, grad
|
|
|
|
class DihedralPotential(Potential):
    """Potential whose per-term variable is the signed dihedral angle phi.

    ``index`` holds four rows of atom indices ``(i, j, k, l)``. phi is the
    angle between the normals of the planes (i, j, k) and (j, k, l), signed
    by the orientation of their cross product relative to the j -> k bond.
    """

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return phi and, if requested, d(phi)/dx for atoms i, j, k, l.

        ``ref_coords`` and ``ref_mask`` are accepted for interface
        compatibility and ignored. The gradient is stacked along dim 1 in the
        order (i, j, k, l). A zero-length plane normal (collinear atoms)
        divides by zero; presumably that yields NaN/inf, which is not guarded
        here.
        """

        def dot(u, v):
            # Row-wise dot product over the last axis, as a (1x3) @ (3x1) matmul.
            return (u.unsqueeze(-2) @ v.unsqueeze(-1)).squeeze(-1, -2)

        x_i = coords.index_select(-2, index[0])
        x_j = coords.index_select(-2, index[1])
        x_k = coords.index_select(-2, index[2])
        x_l = coords.index_select(-2, index[3])
        r_ij = x_i - x_j
        r_kj = x_k - x_j
        r_kl = x_k - x_l

        n_ijk = torch.cross(r_ij, r_kj, dim=-1)
        n_jkl = torch.cross(r_kj, r_kl, dim=-1)
        r_kj_norm = torch.linalg.norm(r_kj, dim=-1)
        n_ijk_norm = torch.linalg.norm(n_ijk, dim=-1)
        n_jkl_norm = torch.linalg.norm(n_jkl, dim=-1)

        sign_phi = torch.sign(dot(r_kj, torch.cross(n_ijk, n_jkl, dim=-1)))
        cos_phi = dot(n_ijk, n_jkl) / (n_ijk_norm * n_jkl_norm)
        # Keep the cosine inside arccos's domain despite round-off.
        phi = sign_phi * torch.arccos(torch.clamp(cos_phi, -1 + 1e-8, 1 - 1e-8))

        if not compute_gradient:
            return phi

        r_kj_sq = r_kj_norm**2
        a = (dot(r_ij, r_kj) / r_kj_sq).unsqueeze(-1)
        b = (dot(r_kl, r_kj) / r_kj_sq).unsqueeze(-1)

        grad_i = n_ijk * (r_kj_norm / n_ijk_norm**2).unsqueeze(-1)
        grad_l = -1 * n_jkl * (r_kj_norm / n_jkl_norm**2).unsqueeze(-1)
        grad_j = (a - 1) * grad_i - b * grad_l
        grad_k = (b - 1) * grad_l - a * grad_i
        return phi, torch.stack((grad_i, grad_j, grad_k, grad_l), dim=1)
|
|
|
|
class AbsDihedralPotential(DihedralPotential):
    """Dihedral potential whose variable is the unsigned angle ``|phi|``."""

    def compute_variable(
        self, coords, index, ref_coords=None, ref_mask=None, compute_gradient=False
    ):
        """Return ``|phi|`` and, if requested, its gradient.

        The gradient is the parent's d(phi)/dx, with sign flipped for every
        term whose signed phi is negative. ``ref_coords`` and ``ref_mask``
        are not forwarded to the parent.
        """
        result = super().compute_variable(
            coords, index, compute_gradient=compute_gradient
        )
        if not compute_gradient:
            return torch.abs(result)

        phi, grad = result
        # Broadcast the per-term mask over the atom (dim -3) and xyz axes.
        negative = (phi < 0)[..., None, :, None]
        grad = torch.where(negative, grad * -1, grad)
        return torch.abs(phi), grad
|
|
|
|
class PoseBustersPotential(FlatBottomPotential, DistancePotential):
    """Flat-bottom distance bounds from RDKit, widened by relative buffers."""

    def compute_args(self, feats, parameters):
        """Build pairwise bounds from the ``rdkit_*`` features of batch item 0.

        Pairs are grouped as bond-only, angle-only, bond-and-angle, or
        neither. Each group's bounds are widened by
        ``bond_buffer`` / ``angle_buffer`` / their minimum / ``clash_buffer``.
        The "neither" group loses its upper bound.

        A cutoff of ``0.35`` plus the mean vdW radius of the two atoms then
        raises the lower bound of non-bond pairs and caps the upper bound of
        bond pairs. All force constants are 1.
        """
        pair_index = feats["rdkit_bounds_index"][0]
        lower_bounds = feats["rdkit_lower_bounds"][0].clone()
        upper_bounds = feats["rdkit_upper_bounds"][0].clone()
        is_bond = feats["rdkit_bounds_bond_mask"][0]
        is_angle = feats["rdkit_bounds_angle_mask"][0]

        bond_buffer = parameters["bond_buffer"]
        angle_buffer = parameters["angle_buffer"]
        relative_buffers = (
            (is_bond * ~is_angle, bond_buffer),
            (~is_bond * is_angle, angle_buffer),
            (is_bond * is_angle, min(bond_buffer, angle_buffer)),
        )
        for category, buffer in relative_buffers:
            lower_bounds[category] *= 1.0 - buffer
            upper_bounds[category] *= 1.0 + buffer

        is_other = ~is_bond * ~is_angle
        lower_bounds[is_other] *= 1.0 - parameters["clash_buffer"]
        upper_bounds[is_other] = float("inf")

        device = pair_index.device
        # Index 0 of the radius table is left at zero.
        vdw_radii = torch.zeros(const.num_elements, dtype=torch.float32, device=device)
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]
        bond_cutoffs = 0.35 + atom_vdw_radii[pair_index].mean(dim=0)

        is_not_bond = ~is_bond
        lower_bounds[is_not_bond] = torch.max(
            lower_bounds[is_not_bond], bond_cutoffs[is_not_bond]
        )
        upper_bounds[is_bond] = torch.min(upper_bounds[is_bond], bond_cutoffs[is_bond])

        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, upper_bounds), None, None, None
|
|
|
|
class ConnectionsPotential(FlatBottomPotential, DistancePotential):
    """Upper-bounds the distance between explicitly connected atom pairs."""

    def compute_args(self, feats, parameters):
        """Use ``connected_atom_index`` (batch item 0) with upper bound ``buffer``.

        No lower bound is set; all force constants are 1.
        """
        pair_index = feats["connected_atom_index"][0]
        n_pairs = pair_index.shape[1]
        upper_bounds = torch.full(
            (n_pairs,), parameters["buffer"], device=pair_index.device
        )
        force_constants = torch.ones_like(upper_bounds)
        return (
            pair_index,
            (force_constants, None, upper_bounds),
            None,
            None,
            None,
        )
|
|
|
|
class VDWOverlapPotential(FlatBottomPotential, DistancePotential):
    """Lower-bounds inter-chain atom distances by the sum of their vdW radii."""

    def compute_args(self, feats, parameters):
        """Build pair bounds for atoms of batch item 0.

        A pair ``i < j`` is kept when:

        * both atoms are unpadded;
        * each atom belongs to a chain with more than one unpadded atom;
        * the two chains are neither the same chain nor listed in
          ``connected_chain_index`` (in either order).

        The lower bound is ``(r_i + r_j) * (1 - buffer)``. There is no upper
        bound, and all force constants are 1.
        """
        token_asym = feats["asym_id"].unsqueeze(-1).float()
        atom_chain_id = (
            torch.bmm(feats["atom_to_token"].float(), token_asym).squeeze(-1).long()
        )[0]
        device = atom_chain_id.device
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        chain_sizes = torch.bincount(atom_chain_id[atom_pad_mask])
        in_multi_atom_chain = (chain_sizes > 1)[atom_chain_id]

        # Index 0 of the radius table is left at zero.
        vdw_radii = torch.zeros(const.num_elements, dtype=torch.float32, device=device)
        vdw_radii[1:119] = torch.tensor(
            const.vdw_radii, dtype=torch.float32, device=device
        )
        atom_vdw_radii = (
            feats["ref_element"].float() @ vdw_radii.unsqueeze(-1)
        ).squeeze(-1)[0]

        n_atoms = atom_chain_id.shape[0]
        pair_index = torch.triu_indices(n_atoms, n_atoms, 1, device=device)
        both_unpadded = atom_pad_mask[pair_index].all(dim=0)
        both_multi_atom = (
            in_multi_atom_chain[pair_index[0]] * in_multi_atom_chain[pair_index[1]]
        )

        # Chain-level adjacency: identity (same chain) plus connected chains.
        num_chains = atom_chain_id.max() + 1
        linked_chains = torch.eye(num_chains, device=device, dtype=torch.bool)
        chain_links = feats["connected_chain_index"][0]
        linked_chains[chain_links[0], chain_links[1]] = True
        linked_chains[chain_links[1], chain_links[0]] = True
        pair_linked = linked_chains[
            atom_chain_id[pair_index[0]], atom_chain_id[pair_index[1]]
        ]

        pair_index = pair_index[:, both_unpadded * both_multi_atom * ~pair_linked]

        lower_bounds = atom_vdw_radii[pair_index].sum(dim=0) * (
            1.0 - parameters["buffer"]
        )
        k = torch.ones_like(lower_bounds)
        return pair_index, (k, lower_bounds, None), None, None, None
|
|
|
|
class SymmetricChainCOMPotential(FlatBottomPotential, DistancePotential):
    """Lower-bounds the distance between centres of symmetric chains."""

    def compute_args(self, feats, parameters):
        """Select ``symmetric_chain_index`` pairs whose chains both have >1 atom.

        The lower bound is ``buffer`` (float32). There is no upper bound, and
        all force constants are 1. Per-atom chain ids and the pad mask are
        returned as ``com_args``, so distances are measured between chain
        centres.
        """
        token_asym = feats["asym_id"].unsqueeze(-1).float()
        atom_chain_id = (
            torch.bmm(feats["atom_to_token"].float(), token_asym).squeeze(-1).long()
        )[0]
        atom_pad_mask = feats["atom_pad_mask"][0].bool()
        multi_atom_chain = torch.bincount(atom_chain_id[atom_pad_mask]) > 1

        candidates = feats["symmetric_chain_index"][0]
        keep = multi_atom_chain[candidates[0]] * multi_atom_chain[candidates[1]]
        pair_index = candidates[:, keep]

        lower_bounds = torch.full(
            (pair_index.shape[1],),
            parameters["buffer"],
            dtype=torch.float32,
            device=pair_index.device,
        )
        k = torch.ones_like(lower_bounds)
        return (
            pair_index,
            (k, lower_bounds, None),
            (atom_chain_id, atom_pad_mask),
            None,
            None,
        )
|
|
|
|
class StereoBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Keeps stereo-bond dihedrals (|phi|) near pi or near 0 per orientation."""

    def compute_args(self, feats, parameters):
        """Bounds on |phi| for each stereo bond of batch item 0.

        Orientation True gives ``[pi - buffer, inf)``. Orientation False gives
        ``(-inf, buffer]``. All force constants are 1.
        """
        stereo_bond_index = feats["stereo_bond_index"][0]
        orientation = feats["stereo_bond_orientations"][0].bool()
        buffer = parameters["buffer"]

        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full_like(lower_bounds, float("inf"))
        lower_bounds[orientation] = torch.pi - buffer
        upper_bounds[~orientation] = buffer

        k = torch.ones_like(lower_bounds)
        return stereo_bond_index, (k, lower_bounds, upper_bounds), None, None, None
|
|
|
|
class ChiralAtomPotential(FlatBottomPotential, DihedralPotential):
    """Constrains the sign of the signed dihedral around each chiral centre."""

    def compute_args(self, feats, parameters):
        """Bounds on phi for each chiral atom of batch item 0.

        Orientation True gives ``[buffer, inf)``. Orientation False gives
        ``(-inf, -buffer]``. All force constants are 1.
        """
        chiral_atom_index = feats["chiral_atom_index"][0]
        orientation = feats["chiral_atom_orientations"][0].bool()
        buffer = parameters["buffer"]

        lower_bounds = torch.full(
            orientation.shape, float("-inf"), device=orientation.device
        )
        upper_bounds = torch.full_like(lower_bounds, float("inf"))
        lower_bounds[orientation] = buffer
        upper_bounds[~orientation] = -1 * buffer

        k = torch.ones_like(lower_bounds)
        return chiral_atom_index, (k, lower_bounds, upper_bounds), None, None, None
|
|
|
|
class PlanarBondPotential(FlatBottomPotential, AbsDihedralPotential):
    """Upper-bounds |phi| of improper dihedrals around planar bonds."""

    def compute_args(self, feats, parameters):
        """Build two improper quadruples per planar bond (batch item 0).

        After the transpose, each bond row presumably holds 6 atom indices.
        The local quadruples ``(1, 2, 3, 0)`` and ``(4, 5, 0, 3)`` are taken
        from that row. Each |phi| gets upper bound ``buffer``; there is no
        lower bound, and all force constants are 1.
        """
        bonds = feats["planar_bond_index"][0].T
        local_quadruples = torch.tensor(
            [[1, 2, 3, 0], [4, 5, 0, 3]],
            device=bonds.device,
        ).T
        improper_index = (
            bonds[:, local_quadruples].transpose(0, 1).flatten(start_dim=1)
        )

        upper_bounds = torch.full(
            (improper_index.shape[1],),
            parameters["buffer"],
            device=improper_index.device,
        )
        k = torch.ones_like(upper_bounds)
        return improper_index, (k, None, upper_bounds), None, None, None
|
|
|
|
class TemplateReferencePotential(FlatBottomPotential, ReferencePotential):
    """Upper-bounds the deviation from forced template (``template_cb``) coordinates."""

    def compute_args(self, feats, parameters):
        """Build reference-potential args from the forced templates.

        Returns an empty ``(1, 0)`` index when the template features are
        missing or no template is forced. Otherwise every masked template
        position gets its template's ``template_force_threshold`` as the
        upper bound, and unmasked positions get ``inf``.
        """
        no_terms = (torch.empty([1, 0]), None, None, None, None)
        if not ("template_mask_cb" in feats and "template_force" in feats):
            return no_terms

        forced = feats["template_force"]
        template_mask = feats["template_mask_cb"][forced]
        if template_mask.shape[0] == 0:
            return no_terms

        ref_coords = feats["template_cb"][forced].clone()
        ref_mask = template_mask.clone()

        atom_pad_mask = feats["atom_pad_mask"]
        atom_positions = torch.arange(
            atom_pad_mask.shape[1],
            device=atom_pad_mask.device,
            dtype=torch.float32,
        )[None, :, None]
        # Representative-atom index per token, read off a one-hot-style matrix.
        ref_atom_index = (
            torch.bmm(feats["token_to_rep_atom"].float(), atom_positions)
            .squeeze(-1)
            .long()
        )[0]
        token_positions = feats["token_index"].unsqueeze(-1).float()
        ref_token_index = (
            torch.bmm(feats["atom_to_token"].float(), token_positions)
            .squeeze(-1)
            .long()
        )[0]

        index = torch.arange(
            template_mask.shape[-1], dtype=torch.long, device=template_mask.device
        )[None]
        upper_bounds = torch.full(
            template_mask.shape, float("inf"), device=index.device, dtype=torch.float32
        )
        masked_positions = torch.argwhere(template_mask).T
        thresholds = feats["template_force_threshold"][forced]
        upper_bounds[tuple(masked_positions)] = thresholds[masked_positions[0]]

        k = torch.ones_like(upper_bounds)
        return (
            index,
            (k, None, upper_bounds),
            None,
            (ref_coords, ref_mask, ref_atom_index, ref_token_index),
            None,
        )
|
|
|
|
class ContactPotentital(FlatBottomPotential, DistancePotential):
    """Upper-bounds contact-pair distances, with union and negation operators.

    The class name is misspelled. It is kept for backward compatibility, and
    a correctly spelled alias ``ContactPotential`` is defined right after it.
    """

    def compute_args(self, feats, parameters):
        """Read the ``contact_*`` features of batch item 0.

        Each pair's upper bound is its ``contact_thresholds`` entry; there is
        no lower bound, and all force constants are 1.
        ``(negation_mask, union_index)`` is returned as ``operator_args``.
        """
        index = feats["contact_pair_index"][0]
        union_index = feats["contact_union_index"][0]
        negation_mask = feats["contact_negation_mask"][0]
        lower_bounds = None
        upper_bounds = feats["contact_thresholds"][0].clone()
        k = torch.ones_like(upper_bounds)
        return (
            index,
            (k, lower_bounds, upper_bounds),
            None,
            None,
            (negation_mask, union_index),
        )


# Correctly spelled alias of ``ContactPotentital``.
ContactPotential = ContactPotentital
|
|
|
|
def get_potentials(steering_args, boltz2=False):
    """Instantiate the steering potentials selected by ``steering_args``.

    Args:
        steering_args: mapping read for ``"fk_steering"`` and
            ``"physical_guidance_update"``. When ``boltz2`` is true it is also
            read for ``"contact_guidance_update"``.
        boltz2: when true, contact and template-reference potentials may be
            added.

    Returns:
        A list of potentials, possibly empty. ``guidance_weight`` is ``0.0``
        whenever the matching guidance-update flag is falsy.
    """
    use_fk_steering = steering_args["fk_steering"]
    physical = steering_args["physical_guidance_update"]

    potentials = []

    if use_fk_steering or physical:
        potentials += [
            SymmetricChainCOMPotential(
                parameters={
                    "guidance_interval": 4,
                    "guidance_weight": 0.5 if physical else 0.0,
                    "resampling_weight": 0.5,
                    "buffer": ExponentialInterpolation(
                        start=1.0, end=5.0, alpha=-2.0
                    ),
                }
            ),
            VDWOverlapPotential(
                parameters={
                    "guidance_interval": 5,
                    "guidance_weight": (
                        PiecewiseStepFunction(thresholds=[0.4], values=[0.125, 0.0])
                        if physical
                        else 0.0
                    ),
                    "resampling_weight": PiecewiseStepFunction(
                        thresholds=[0.6], values=[0.01, 0.0]
                    ),
                    "buffer": 0.225,
                }
            ),
            ConnectionsPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.15 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 2.0,
                }
            ),
            PoseBustersPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.01 if physical else 0.0,
                    "resampling_weight": 0.1,
                    "bond_buffer": 0.125,
                    "angle_buffer": 0.125,
                    "clash_buffer": 0.10,
                }
            ),
            ChiralAtomPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.1 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            StereoBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.52360,
                }
            ),
            PlanarBondPotential(
                parameters={
                    "guidance_interval": 1,
                    "guidance_weight": 0.05 if physical else 0.0,
                    "resampling_weight": 1.0,
                    "buffer": 0.26180,
                }
            ),
        ]

    if boltz2:
        contact = steering_args["contact_guidance_update"]
        if use_fk_steering or contact:
            potentials += [
                ContactPotentital(
                    parameters={
                        "guidance_interval": 4,
                        "guidance_weight": (
                            PiecewiseStepFunction(
                                thresholds=[0.25, 0.75], values=[0.0, 0.5, 1.0]
                            )
                            if contact
                            else 0.0
                        ),
                        "resampling_weight": 1.0,
                        "union_lambda": ExponentialInterpolation(
                            start=8.0, end=0.0, alpha=-2.0
                        ),
                    }
                ),
                TemplateReferencePotential(
                    parameters={
                        "guidance_interval": 2,
                        "guidance_weight": 0.1 if contact else 0.0,
                        "resampling_weight": 1.0,
                    }
                ),
            ]

    return potentials
|
|