| |
|
|
| from functools import partial |
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.nn import ( |
| Linear, |
| Module, |
| ) |
| from torch.types import Device |
|
|
# Convenience factory: a ``torch.nn.Linear`` constructed with ``bias=False``.
LinearNoBias = partial(Linear, bias=False)
|
|
|
|
def exists(v):
    """Return True when *v* holds a value, i.e. is anything but None."""
    return not (v is None)
|
|
|
|
def default(v, d):
    """Return *v* unless it is None, in which case fall back to *d*."""
    if v is None:
        return d
    return v
|
|
|
|
def log(t, eps=1e-20):
    """Numerically safe natural log: values below *eps* are clamped first."""
    safe = t.clamp(min=eps)
    return safe.log()
|
|
|
|
class SwiGLU(Module):
    """Gated activation: split the last dimension into (value, gate) halves
    and return ``value * silu(gate)``. The input's last dim must be even;
    the output's last dim is half the input's."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)
|
|
|
|
def center(atom_coords, atom_mask):
    """Translate coordinates so their mask-weighted centroid is the origin.

    Args:
        atom_coords: coordinate tensor; assumed (batch, atoms, 3) — the mean
            is taken over dim 1.
        atom_mask: per-atom weight/mask, assumed (batch, atoms).

    Returns:
        Coordinates with the masked mean subtracted (same shape as input).
    """
    weights = atom_mask[:, :, None]
    centroid = (atom_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
        dim=1, keepdim=True
    )
    return atom_coords - centroid
|
|
|
|
def compute_random_augmentation(
    multiplicity, s_trans=1.0, device=None, dtype=torch.float32
):
    """Sample a batch of random rigid transforms.

    Args:
        multiplicity: number of independent transforms to draw.
        s_trans: scale applied to the Gaussian translations.
        device: target device for the sampled tensors.
        dtype: dtype for the sampled tensors.

    Returns:
        Pair ``(R, t)``: rotation matrices of shape (multiplicity, 3, 3) and
        translations of shape (multiplicity, 1, 3).
    """
    rotations = random_rotations(multiplicity, dtype=dtype, device=device)
    translations = s_trans * torch.randn(
        (multiplicity, 1, 3), dtype=dtype, device=device
    )
    return rotations, translations
|
|
|
|
def randomly_rotate(coords, return_second_coords=False, second_coords=None):
    """Apply an independent random rotation to each batch element of *coords*.

    When *return_second_coords* is True, a pair is returned: the rotated
    *coords* plus *second_coords* rotated by the same matrices (or None when
    no second set was supplied). Otherwise only the rotated *coords*.
    """
    R = random_rotations(len(coords), coords.dtype, coords.device)
    rotated = torch.einsum("bmd,bds->bms", coords, R)

    if not return_second_coords:
        return rotated

    if second_coords is None:
        return rotated, None
    return rotated, torch.einsum("bmd,bds->bms", second_coords, R)
|
|
|
|
def center_random_augmentation(
    atom_coords,
    atom_mask,
    s_trans=1.0,
    augmentation=True,
    centering=True,
    return_second_coords=False,
    second_coords=None,
):
    """Algorithm 19: optionally center coordinates at their masked mean, then
    optionally apply a random rotation plus a Gaussian translation scaled by
    *s_trans*. Any *second_coords* receive the identical centering, rotation
    and translation so the two coordinate sets stay in the same frame.

    Returns ``(atom_coords, second_coords)`` when *return_second_coords* is
    True, otherwise just the transformed *atom_coords*.
    """
    if centering:
        weights = atom_mask[:, :, None]
        mean = (atom_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
            dim=1, keepdim=True
        )
        atom_coords = atom_coords - mean
        if second_coords is not None:
            # Shift by the SAME mean (not second_coords' own mean).
            second_coords = second_coords - mean

    if augmentation:
        atom_coords, second_coords = randomly_rotate(
            atom_coords, return_second_coords=True, second_coords=second_coords
        )
        # One shared translation per batch element, broadcast over atoms.
        trans = s_trans * torch.randn_like(atom_coords[:, 0:1, :])
        atom_coords = atom_coords + trans
        if second_coords is not None:
            second_coords = second_coords + trans

    if return_second_coords:
        return atom_coords, second_coords
    return atom_coords
|
|
|
|
class ExponentialMovingAverage:
    """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license
    Maintains (exponential) moving average of a set of parameters."""

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result of
                `model.parameters()`.
            decay: The exponential decay. Must be in [0, 1].
            use_num_updates: Whether to use number of updates when computing
                averages.

        Raises:
            ValueError: If `decay` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self.decay = decay
        # None disables the warm-up schedule applied in `update`.
        self.num_updates = 0 if use_num_updates else None
        # Detached clones of the TRAINABLE parameters only; frozen
        # (requires_grad=False) parameters are never tracked.
        self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
        # Scratch storage used by the `store`/`restore` round-trip.
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.
        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: effective decay ramps from ~0.2 toward `self.decay`
            # so early averages are not dominated by the initial weights.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            # Filter to trainable params so the zip pairing matches how
            # `shadow_params` was built in `__init__`.
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                # In-place EMA step: s = decay * s + (1 - decay) * p.
                s_param.sub_(one_minus_decay * (s_param - param))

    def compatible(self, parameters):
        # Check that `parameters` (a sequence supporting len()) matches the
        # stored shadow params in count and per-tensor shape; prints a
        # diagnostic and returns False on the first mismatch found.
        if len(self.shadow_params) != len(parameters):
            print(
                f"Model has {len(self.shadow_params)} parameter tensors, the incoming ema {len(parameters)}"
            )
            return False

        for s_param, param in zip(self.shadow_params, parameters):
            if param.data.shape != s_param.data.shape:
                print(
                    f"Model has parameter tensor of shape {s_param.data.shape} , the incoming ema {param.data.shape}"
                )
                return False
        return True

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            # The inner requires_grad check is redundant after the filter
            # above; kept as-is for safety.
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        # NOTE(review): unlike `update`/`copy_to`, this does NOT filter on
        # requires_grad — pass the same iterable to `restore` so the zip
        # pairing stays consistent.
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        # Shadow tensors are exposed by reference (not cloned); mutating the
        # returned dict's tensors mutates this EMA's state.
        return dict(
            decay=self.decay,
            num_updates=self.num_updates,
            shadow_params=self.shadow_params,
        )

    def load_state_dict(self, state_dict, device):
        # Restore EMA state produced by `state_dict`, moving every shadow
        # tensor to `device`.
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = [
            tensor.to(device) for tensor in state_dict["shadow_params"]
        ]

    def to(self, device):
        # Move all shadow parameters to `device`.
        self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]
|
|
|
|
| |
|
|
|
|
| def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| """ |
| Return a tensor where each element has the absolute value taken from the, |
| corresponding element of a, with sign taken from the corresponding |
| element of b. This is like the standard copysign floating-point operation, |
| but is not careful about negative 0 and NaN. |
| |
| Args: |
| a: source tensor. |
| b: tensor whose signs will be used, of the same shape as a. |
| |
| Returns: |
| Tensor of the same shape as a with the signs of b. |
| """ |
| signs_differ = (a < 0) != (b < 0) |
| return torch.where(signs_differ, -a, a) |
|
|
|
|
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first, shape (..., 4).

    Returns:
        Rotation matrices, shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # Normalization factor; equals exactly 2 for unit quaternions.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # The nine matrix entries, row-major.
    entries = [
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
|
|
|
|
def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """Draw *n* random unit quaternions (versors) with nonnegative real part.

    Args:
        n: number of quaternions to sample.
        dtype: dtype of the returned tensor.
        device: target device; a string is converted via `torch.device`.
            Default: the current device for the default tensor type.

    Returns:
        Quaternions as a tensor of shape (n, 4), real component first.
    """
    if isinstance(device, str):
        device = torch.device(device)
    # Normalizing an isotropic Gaussian yields a uniform point on S^3.
    raw = torch.randn((n, 4), dtype=dtype, device=device)
    norms = raw.pow(2).sum(dim=1).sqrt()
    # Dividing by a norm that carries the sign of the real component maps
    # any quaternion with negative real part to its equivalent antipode.
    return raw / _copysign(norms, raw[:, 0])[:, None]
|
|
|
|
def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """Sample *n* random rotations as 3x3 rotation matrices.

    Args:
        n: number of rotation matrices to return.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor. Default: if None, the current
            device for the default tensor type.

    Returns:
        Rotation matrices as a tensor of shape (n, 3, 3).
    """
    versors = random_quaternions(n, dtype=dtype, device=device)
    return quaternion_to_matrix(versors)
|
|