| import numpy as np | |
| import torch | |
| from typing import Union, List | |
class linear:
    """Linear interpolation (LERP) between two values.

    ``execute`` returns ``t * v1 + (1.0 - t) * v0``; ``t == 0`` yields ``v0``
    and ``t == 1`` yields ``v1``. Arguments may be given either bare or wrapped
    in a list/tuple, in which case only the first element is used.
    """

    def __init__(self):
        pass

    @staticmethod
    def _unwrap(value, name):
        """Return ``value[0]`` for a list/tuple argument, else ``value`` as-is.

        Raises:
            ValueError: if ``value`` is an empty list/tuple.
        """
        # isinstance (not an exact type check) so tuples and list subclasses
        # are unwrapped too instead of reaching the arithmetic below.
        if isinstance(value, (list, tuple)):
            if not value:
                raise ValueError(f"{name} must not be an empty sequence")
            # Only the first element is used; any further elements are ignored.
            return value[0]
        return value

    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities=None,
    ):
        """Linearly interpolate between ``v0`` and ``v1`` with weight ``t``.

        Args:
            t: Interpolation weight, or a list/tuple whose first element is it.
            v0: Start value (weight ``1 - t``), or a list/tuple whose first
                element is it.
            v1: End value (weight ``t``), or a list/tuple whose first element
                is it.
            DOT_THRESHOLD: Not used by this method.
            eps: Not used by this method.
            densities: Not used by this method.

        NOTE(review): the unused parameters are presumably kept so this class
        shares a call signature with sibling merge methods (e.g. a SLERP
        implementation) — confirm against callers before removing them.

        Returns:
            ``t * v1 + (1.0 - t) * v0``, computed with the operands' own
            arithmetic (so tensors broadcast per their library's rules).

        Raises:
            ValueError: if any of ``t``, ``v0``, ``v1`` is an empty list/tuple.
        """
        v0 = self._unwrap(v0, "v0")
        t = self._unwrap(t, "t")
        v1 = self._unwrap(v1, "v1")
        return t * v1 + (1.0 - t) * v0