File size: 1,436 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torch import Tensor

import torch

class Functional:
    '''Collection of static tensor helper functions.'''

    @staticmethod
    def slerp(low: Tensor,
              high: Tensor,
              val: float = 0.5
              ) -> Tensor:
        '''
        Spherical Linear Interpolation (Slerp), applied row by row.

        Slerp(q_0, q_1; t) = q_0 (q_0^-1 q_1)^t
                           = sin((1 - t) * theta) / sin(theta) * q_0
                             + sin(t * theta) / sin(theta) * q_1
                           where dot_product(q_0 / |q_0|, q_1 / |q_1|) = cos(theta)

        The first dimension of the inputs is treated as the batch dimension:
        every other dimension is flattened into one vector per batch entry,
        the angle is computed per entry, and the result is reshaped back to
        the input shape.

        Degenerate rows, for which sin(theta) is (numerically) zero, i.e.
        the two directions are parallel or antipodal, cannot use the
        formula above because it divides by sin(theta). Those rows fall
        back to plain linear interpolation (1 - val) * low + val * high,
        which is the limit of slerp as theta -> 0.

        Args:
            low: start tensor (returned for val = 0).
            high: end tensor (returned for val = 1); same shape as low.
            val: interpolation factor.

        Returns:
            Tensor with the same shape as low.

        Raises:
            AssertionError: if low and high do not have the same shape.
        '''
        assert tuple(low.shape) == tuple(high.shape), f'low shape({low.shape}) must be same as high shape({high.shape})'
        feature_shape: tuple = tuple(low.shape)

        # Flatten everything except the batch dimension into one vector per row.
        low_1d: Tensor = low.reshape(feature_shape[0], -1)
        high_1d: Tensor = high.reshape(feature_shape[0], -1)

        # Normalize to unit directions. The norm is floored at the smallest
        # positive normal number so an all-zero row yields a zero direction
        # instead of 0/0 = NaN.
        low_len = torch.norm(low_1d, dim=1, keepdim=True)
        high_len = torch.norm(high_1d, dim=1, keepdim=True)
        low_norm = low_1d / low_len.clamp(min=torch.finfo(low_len.dtype).tiny)
        high_norm = high_1d / high_len.clamp(min=torch.finfo(high_len.dtype).tiny)

        # Rounding can push the cosine slightly outside [-1, 1], where acos
        # returns NaN, so clamp it first.
        dot_product = (low_norm * high_norm).sum(dim=1).clamp(-1.0, 1.0)
        theta = torch.acos(dot_product)
        so = torch.sin(theta)

        # Rows whose sin(theta) is zero up to rounding noise (theta ~ 0 or pi).
        degenerate = so.abs() < 1e-6
        # Substitute 1 for the denominator on those rows so that neither the
        # forward nor the backward pass ever evaluates a division by zero.
        safe_so = torch.where(degenerate, torch.ones_like(so), so)
        ones = torch.ones_like(so)

        low_weight = torch.where(degenerate,
                                 (1.0 - val) * ones,
                                 torch.sin((1.0 - val) * theta) / safe_so)
        high_weight = torch.where(degenerate,
                                  val * ones,
                                  torch.sin(val * theta) / safe_so)

        res = low_weight.unsqueeze(1) * low_1d + high_weight.unsqueeze(1) * high_1d
        return res.reshape(feature_shape)