GlandVergil's picture
Upload 686 files
3cdaa7d verified
raw
history blame
8.95 kB
"""Helper class for handle symmetric assemblies."""
from pyrsistent import v
from scipy.spatial.transform import Rotation
import functools as fn
import torch
import string
import logging
import numpy as np
import pathlib
format_rots = lambda r: torch.tensor(r).float()
T3_ROTATIONS = [
torch.Tensor([
[ 1., 0., 0.],
[ 0., 1., 0.],
[ 0., 0., 1.]]).float(),
torch.Tensor([
[-1., -0., 0.],
[-0., 1., 0.],
[-0., 0., -1.]]).float(),
torch.Tensor([
[-1., 0., 0.],
[ 0., -1., 0.],
[ 0., 0., 1.]]).float(),
torch.Tensor([
[ 1., 0., 0.],
[ 0., -1., 0.],
[ 0., 0., -1.]]).float(),
]
saved_symmetries = ['tetrahedral', 'octahedral', 'icosahedral']
class SymGen:
def __init__(self, global_sym, recenter, radius, model_only_neighbors=False):
self._log = logging.getLogger(__name__)
self._recenter = recenter
self._radius = radius
if global_sym.lower().startswith('c'):
# Cyclic symmetry
if not global_sym[1:].isdigit():
raise ValueError(f'Invalid cyclic symmetry {global_sym}')
self._log.info(
f'Initializing cyclic symmetry order {global_sym[1:]}.')
self._init_cyclic(int(global_sym[1:]))
self.apply_symmetry = self._apply_cyclic
elif global_sym.lower().startswith('d'):
# Dihedral symmetry
if not global_sym[1:].isdigit():
raise ValueError(f'Invalid dihedral symmetry {global_sym}')
self._log.info(
f'Initializing dihedral symmetry order {global_sym[1:]}.')
self._init_dihedral(int(global_sym[1:]))
# Applied the same way as cyclic symmetry
self.apply_symmetry = self._apply_cyclic
elif global_sym.lower() == 't3':
# Tetrahedral (T3) symmetry
self._log.info('Initializing T3 symmetry order.')
self.sym_rots = T3_ROTATIONS
self.order = 4
# Applied the same way as cyclic symmetry
self.apply_symmetry = self._apply_cyclic
elif global_sym == 'octahedral':
# Octahedral symmetry
self._log.info(
'Initializing octahedral symmetry.')
self._init_octahedral()
self.apply_symmetry = self._apply_octahedral
elif global_sym.lower() in saved_symmetries:
# Using a saved symmetry
self._log.info('Initializing %s symmetry order.'%global_sym)
self._init_from_symrots_file(global_sym)
# Applied the same way as cyclic symmetry
self.apply_symmetry = self._apply_cyclic
else:
raise ValueError(f'Unrecognized symmetry {global_sym}')
self.res_idx_procesing = fn.partial(
self._lin_chainbreaks, num_breaks=self.order)
#####################
## Cyclic symmetry ##
#####################
def _init_cyclic(self, order):
sym_rots = []
for i in range(order):
deg = i * 360.0 / order
r = Rotation.from_euler('z', deg, degrees=True)
sym_rots.append(format_rots(r.as_matrix()))
self.sym_rots = sym_rots
self.order = order
def _apply_cyclic(self, coords_in, seq_in):
coords_out = torch.clone(coords_in)
seq_out = torch.clone(seq_in)
if seq_out.shape[0] % self.order != 0:
raise ValueError(
f'Sequence length must be divisble by {self.order}')
subunit_len = seq_out.shape[0] // self.order
for i in range(self.order):
start_i = subunit_len * i
end_i = subunit_len * (i+1)
coords_out[start_i:end_i] = torch.einsum(
'bnj,kj->bnk', coords_out[:subunit_len], self.sym_rots[i])
seq_out[start_i:end_i] = seq_out[:subunit_len]
return coords_out, seq_out
def _lin_chainbreaks(self, num_breaks, res_idx, offset=None):
assert res_idx.ndim == 2
res_idx = torch.clone(res_idx)
subunit_len = res_idx.shape[-1] // num_breaks
chain_delimiters = []
if offset is None:
offset = res_idx.shape[-1]
for i in range(num_breaks):
start_i = subunit_len * i
end_i = subunit_len * (i+1)
chain_labels = list(string.ascii_uppercase) + [str(i+j) for i in
string.ascii_uppercase for j in string.ascii_uppercase]
chain_delimiters.extend(
[chain_labels[i] for _ in range(subunit_len)]
)
res_idx[:, start_i:end_i] = res_idx[:, start_i:end_i] + offset * (i+1)
return res_idx, chain_delimiters
#######################
## Dihedral symmetry ##
#######################
def _init_dihedral(self, order):
sym_rots = []
flip = Rotation.from_euler('x', 180, degrees=True).as_matrix()
for i in range(order):
deg = i * 360.0 / order
rot = Rotation.from_euler('z', deg, degrees=True).as_matrix()
sym_rots.append(format_rots(rot))
rot2 = flip @ rot
sym_rots.append(format_rots(rot2))
self.sym_rots = sym_rots
self.order = order * 2
#########################
## Octahedral symmetry ##
#########################
def _init_octahedral(self):
sym_rots = np.load(f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz")
self.sym_rots = [
torch.tensor(v_i, dtype=torch.float32)
for v_i in sym_rots['octahedral']
]
self.order = len(self.sym_rots)
def _apply_octahedral(self, coords_in, seq_in):
coords_out = torch.clone(coords_in)
seq_out = torch.clone(seq_in)
if seq_out.shape[0] % self.order != 0:
raise ValueError(
f'Sequence length must be divisble by {self.order}')
subunit_len = seq_out.shape[0] // self.order
base_axis = torch.tensor([self._radius, 0., 0.])[None]
for i in range(self.order):
start_i = subunit_len * i
end_i = subunit_len * (i+1)
subunit_chain = torch.einsum(
'bnj,kj->bnk', coords_in[:subunit_len], self.sym_rots[i])
if self._recenter:
center = torch.mean(subunit_chain[:, 1, :], axis=0)
subunit_chain -= center[None, None, :]
rotated_axis = torch.einsum(
'nj,kj->nk', base_axis, self.sym_rots[i])
subunit_chain += rotated_axis[:, None, :]
coords_out[start_i:end_i] = subunit_chain
seq_out[start_i:end_i] = seq_out[:subunit_len]
return coords_out, seq_out
#######################
## symmetry from file #
#######################
def _init_from_symrots_file(self, name):
""" _init_from_symrots_file initializes using
./inference/sym_rots.npz
Args:
name: name of symmetry (of tetrahedral, octahedral, icosahedral)
sets self.sym_rots to be a list of torch.tensor of shape [3, 3]
"""
assert name in saved_symmetries, name + " not in " + str(saved_symmetries)
# Load in list of rotation matrices for `name`
fn = f"{pathlib.Path(__file__).parent.resolve()}/sym_rots.npz"
obj = np.load(fn)
symms = None
for k, v in obj.items():
if str(k) == name: symms = v
assert symms is not None, "%s not found in %s"%(name, fn)
self.sym_rots = [torch.tensor(v_i, dtype=torch.float32) for v_i in symms]
self.order = len(self.sym_rots)
# Return if identity is the first rotation
if not np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0):
# Move identity to be the first rotation
for i, rot in enumerate(self.sym_rots):
if np.isclose(((rot-np.eye(3))**2).sum(), 0):
self.sym_rots = [self.sym_rots.pop(i)] + self.sym_rots
assert len(self.sym_rots) == self.order
assert np.isclose(((self.sym_rots[0]-np.eye(3))**2).sum(), 0)
def close_neighbors(self):
"""close_neighbors finds the rotations within self.sym_rots that
correspond to close neighbors.
Returns:
list of rotation matrices corresponding to the identity and close neighbors
"""
# set of small rotation angle rotations
rel_rot = lambda M: np.linalg.norm(Rotation.from_matrix(M).as_rotvec())
rel_rots = [(i+1, rel_rot(M)) for i, M in enumerate(self.sym_rots[1:])]
min_rot = min(rel_rot_val[1] for rel_rot_val in rel_rots)
close_rots = [np.eye(3)] + [
self.sym_rots[i] for i, rel_rot_val in rel_rots if
np.isclose(rel_rot_val, min_rot)
]
return close_rots