| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from functools import lru_cache |
| from typing import Dict, List |
|
|
| import e3nn.o3 as o3 |
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torch.cuda.nvtx import range as nvtx_range |
|
|
| from se3_transformer.runtime.utils import degree_to_dim |
|
|
|
|
| @lru_cache(maxsize=None) |
| def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor: |
| """ Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """ |
| return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0) |
|
|
|
|
| @lru_cache(maxsize=None) |
| def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]: |
| all_cb = [] |
| for d_in in range(max_degree + 1): |
| for d_out in range(max_degree + 1): |
| K_Js = [] |
| for J in range(abs(d_in - d_out), d_in + d_out + 1): |
| K_Js.append(get_clebsch_gordon(J, d_in, d_out, device)) |
| all_cb.append(K_Js) |
| return all_cb |
|
|
|
|
| def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]: |
| all_degrees = list(range(2 * max_degree + 1)) |
| with nvtx_range('spherical harmonics'): |
| sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True) |
| return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1) |
|
|
|
|
| @torch.jit.script |
| def get_basis_script(max_degree: int, |
| use_pad_trick: bool, |
| spherical_harmonics: List[Tensor], |
| clebsch_gordon: List[List[Tensor]], |
| amp: bool) -> Dict[str, Tensor]: |
| """ |
| Compute pairwise bases matrices for degrees up to max_degree |
| :param max_degree: Maximum input or output degree |
| :param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores |
| :param spherical_harmonics: List of computed spherical harmonics |
| :param clebsch_gordon: List of computed CB-coefficients |
| :param amp: When true, return bases in FP16 precision |
| """ |
| basis = {} |
| idx = 0 |
| |
| for d_in in range(max_degree + 1): |
| for d_out in range(max_degree + 1): |
| key = f'{d_in},{d_out}' |
| K_Js = [] |
| for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)): |
| Q_J = clebsch_gordon[idx][freq_idx] |
| K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float())) |
|
|
| basis[key] = torch.stack(K_Js, 2) |
| if amp: |
| basis[key] = basis[key].half() |
| if use_pad_trick: |
| basis[key] = F.pad(basis[key], (0, 1)) |
|
|
| idx += 1 |
|
|
| return basis |
|
|
|
|
| @torch.jit.script |
| def update_basis_with_fused(basis: Dict[str, Tensor], |
| max_degree: int, |
| use_pad_trick: bool, |
| fully_fused: bool) -> Dict[str, Tensor]: |
| """ Update the basis dict with partially and optionally fully fused bases """ |
| num_edges = basis['0,0'].shape[0] |
| device = basis['0,0'].device |
| dtype = basis['0,0'].dtype |
| sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)]) |
|
|
| |
| for d_out in range(max_degree + 1): |
| sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)]) |
| basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick), |
| device=device, dtype=dtype) |
| acc_d, acc_f = 0, 0 |
| for d_in in range(max_degree + 1): |
| basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)), |
| :degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)] |
|
|
| acc_d += degree_to_dim(d_in) |
| acc_f += degree_to_dim(min(d_out, d_in)) |
|
|
| basis[f'out{d_out}_fused'] = basis_fused |
|
|
| |
| for d_in in range(max_degree + 1): |
| sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)]) |
| basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim, |
| device=device, dtype=dtype) |
| acc_d, acc_f = 0, 0 |
| for d_out in range(max_degree + 1): |
| basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \ |
| = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)] |
|
|
| acc_d += degree_to_dim(d_out) |
| acc_f += degree_to_dim(min(d_out, d_in)) |
|
|
| basis[f'in{d_in}_fused'] = basis_fused |
|
|
| if fully_fused: |
| |
| |
| sum_freq = sum([ |
| sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1) |
| ]) |
| basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype) |
|
|
| acc_d, acc_f = 0, 0 |
| for d_out in range(max_degree + 1): |
| b = basis[f'out{d_out}_fused'] |
| basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :, |
| :degree_to_dim(d_out)] |
| acc_f += b.shape[2] |
| acc_d += degree_to_dim(d_out) |
|
|
| basis['fully_fused'] = basis_fused |
|
|
| del basis['0,0'] |
| return basis |
|
|
|
|
| def get_basis(relative_pos: Tensor, |
| max_degree: int = 4, |
| compute_gradients: bool = False, |
| use_pad_trick: bool = False, |
| amp: bool = False) -> Dict[str, Tensor]: |
| with nvtx_range('spherical harmonics'): |
| spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree) |
| with nvtx_range('CB coefficients'): |
| clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device) |
|
|
| with torch.autograd.set_grad_enabled(compute_gradients): |
| with nvtx_range('bases'): |
| basis = get_basis_script(max_degree=max_degree, |
| use_pad_trick=use_pad_trick, |
| spherical_harmonics=spherical_harmonics, |
| clebsch_gordon=clebsch_gordon, |
| amp=amp) |
| return basis |
|
|