| | """SO(3) diffusion methods.""" |
| | import numpy as np |
| | import os |
| | from functools import cached_property |
| | import torch |
| | from scipy.spatial.transform import Rotation |
| | import scipy.linalg |
| |
|
| |
|
| | |
| |
|
| | |
| | def hat(v): |
| | hat_v = torch.zeros([v.shape[0], 3, 3]) |
| | hat_v[:, 0, 1], hat_v[:, 0, 2], hat_v[:, 1, 2] = -v[:, 2], v[:, 1], -v[:, 0] |
| | return hat_v + -hat_v.transpose(2, 1) |
| |
|
| | |
| | def Log(R): return torch.tensor(Rotation.from_matrix(R.numpy()).as_rotvec()) |
| | |
| | |
| | def log(R): return hat(Log(R)) |
| |
|
| | |
| | |
| | def Exp(A): return torch.tensor(Rotation.from_rotvec(A.numpy()).as_matrix()) |
| |
|
| | |
| | def Omega(R): return np.linalg.norm(log(R), axis=[-2, -1])/np.sqrt(2.) |
| |
|
| | L_default = 2000 |
| | def f_igso3(omega, t, L=L_default): |
| | """Truncated sum of IGSO(3) distribution. |
| | |
| | This function approximates the power series in equation 5 of |
| | "DENOISING DIFFUSION PROBABILISTIC MODELS ON SO(3) FOR ROTATIONAL |
| | ALIGNMENT" |
| | Leach et al. 2022 |
| | |
| | This expression diverges from the expression in Leach in that here, sigma = |
| | sqrt(2) * eps, if eps_leach were the scale parameter of the IGSO(3). |
| | |
| | With this reparameterization, IGSO(3) agrees with the Brownian motion on |
| | SO(3) with t=sigma^2 when defined for the canonical inner product on SO3, |
| | <u, v>_SO3 = Trace(u v^T)/2 |
| | |
| | Args: |
| | omega: i.e. the angle of rotation associated with rotation matrix |
| | t: variance parameter of IGSO(3), maps onto time in Brownian motion |
| | L: Truncation level |
| | """ |
| | ls = torch.arange(L)[None] |
| | return ((2*ls + 1) * torch.exp(-ls*(ls+1)*t/2) * |
| | torch.sin(omega[:, None]*(ls+1/2)) / torch.sin(omega[:, None]/2)).sum(dim=-1) |
| |
|
| | def d_logf_d_omega(omega, t, L=L_default): |
| | omega = torch.tensor(omega, requires_grad=True) |
| | log_f = torch.log(f_igso3(omega, t, L)) |
| | return torch.autograd.grad(log_f.sum(), omega)[0].numpy() |
| |
|
| | |
| | def igso3_density(Rt, t, L=L_default): |
| | return f_igso3(torch.tensor(Omega(Rt)), t, L).numpy() |
| |
|
| | def igso3_density_angle(omega, t, L=L_default): |
| | return f_igso3(torch.tensor(omega), t, L).numpy()*(1-np.cos(omega))/np.pi |
| |
|
| | |
| | def igso3_score(R, t, L=L_default): |
| | omega = Omega(R) |
| | unit_vector = np.einsum('Nij,Njk->Nik', R, log(R))/omega[:, None, None] |
| | return unit_vector * d_logf_d_omega(omega, t, L)[:, None, None] |
| |
|
| | def calculate_igso3(*, num_sigma, num_omega, min_sigma, max_sigma): |
| | """calculate_igso3 pre-computes numerical approximations to the IGSO3 cdfs |
| | and score norms and expected squared score norms. |
| | |
| | Args: |
| | num_sigma: number of different sigmas for which to compute igso3 |
| | quantities. |
| | num_omega: number of point in the discretization in the angle of |
| | rotation. |
| | min_sigma, max_sigma: the upper and lower ranges for the angle of |
| | rotation on which to consider the IGSO3 distribution. This cannot |
| | be too low or it will create numerical instability. |
| | """ |
| | |
| | discrete_omega = np.linspace(0, np.pi, num_omega+1)[1:] |
| |
|
| | |
| | |
| | |
| | discrete_sigma = 10 ** np.linspace(np.log10(min_sigma), np.log10(max_sigma), num_sigma + 1)[1:] |
| |
|
| | |
| | |
| | pdf_vals = np.asarray( |
| | [igso3_density_angle(discrete_omega, sigma**2) for sigma in discrete_sigma]) |
| | cdf_vals = np.asarray( |
| | [pdf.cumsum() / num_omega * np.pi for pdf in pdf_vals]) |
| |
|
| | |
| | |
| | score_norm = np.asarray( |
| | [d_logf_d_omega(discrete_omega, sigma**2) for sigma in discrete_sigma]) |
| |
|
| | |
| | exp_score_norms = np.sqrt( |
| | np.sum( |
| | score_norm**2 * pdf_vals, axis=1) / np.sum( |
| | pdf_vals, axis=1)) |
| | return { |
| | 'cdf': cdf_vals, |
| | 'score_norm': score_norm, |
| | 'exp_score_norms': exp_score_norms, |
| | 'discrete_omega': discrete_omega, |
| | 'discrete_sigma': discrete_sigma, |
| | } |
| |
|