DDHpose / common /quaternion.py
Andyen512
Add model checkpoints and configs
1e45055
raw
history blame contribute delete
808 Bytes
import torch
def qrot(q, v):
    """
    Apply the rotation encoded by quaternion(s) ``q`` to 3-vector(s) ``v``.

    Quaternions are laid out scalar-first as (w, x, y, z): component 0 is the
    scalar part and components 1..3 are the vector part ``u``. The result is
    computed with the cross-product form of ``q * v * q^-1``:

        v' = v + 2 * (w * (u x v) + u x (u x v))

    This identity only describes a pure rotation when ``q`` has unit norm.
    The norm is assumed, not checked.

    Args:
        q: tensor of shape (*, 4).
        v: tensor of shape (*, 3) whose leading dimensions equal those of q.
           There is no broadcasting between q and v.

    Returns:
        Tensor of shape (*, 3) holding the rotated vectors.

    Raises:
        AssertionError: if the shapes do not match the contract above. The
        checks are plain ``assert`` statements, so they are skipped under -O.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    # The cross products act along the trailing (component) axis.
    last_axis = q.ndim - 1
    scalar = q[..., :1]   # keep the trailing axis so it broadcasts against (*, 3)
    axis_part = q[..., 1:]

    u_cross_v = torch.cross(axis_part, v, dim=last_axis)
    u_cross_u_cross_v = torch.cross(axis_part, u_cross_v, dim=last_axis)
    return v + 2 * (scalar * u_cross_v + u_cross_u_cross_v)
def qinverse(q, inplace=False):
    """
    Return the inverse of quaternion(s) ``q``, computed as the conjugate.

    The conjugate negates the vector part (components 1..3 of the last axis)
    and keeps the scalar part (component 0). The conjugate equals the inverse
    only for unit quaternions, which is assumed here and not checked.

    Args:
        q: quaternion tensor. Presumably shape (*, 4), matching qrot, but the
           shape is not validated here.
        inplace: if True, negate the vector part of ``q`` in place and return
           the same object. If False (the default), leave ``q`` untouched and
           return a newly allocated tensor.

    Returns:
        The conjugated quaternion(s). This is ``q`` itself when ``inplace`` is
        True.
    """
    if not inplace:
        return torch.cat((q[..., :1], -q[..., 1:]), dim=q.ndim - 1)

    # Mutates the caller's tensor.
    q[..., 1:] *= -1
    return q