Commit 475ce7c
Parent(s): 6894a19
initial commit
Files changed:
- basis_gn.py +349 -0
- beast.py +280 -0
- bspline_factory.py +25 -0
- uni_bspline.py +462 -0
- utils.py +166 -0
basis_gn.py
ADDED
@@ -0,0 +1,349 @@
"""
@brief: Basis generators in PyTorch
"""
from typing import Tuple

import torch


class UniBSplineBasis(torch.nn.Module):

    def __init__(self,
                 num_basis: int = 10,
                 degree_p: int = 3,
                 dtype: torch.dtype = torch.float32,
                 device: torch.device = 'cpu',
                 **kwargs):
        """
        Constructor for the basis class
        Args:
            num_basis: number of basis functions
            degree_p: B-spline degree
            dtype: torch data type
            device: torch device to run on
        """
        super().__init__()

        # Internal number of basis functions
        self.num_basis = num_basis

        self.degree_p = degree_p
        self.init_cond_order = kwargs.get("init_condition_order", 0)
        self.end_cond_order = kwargs.get("end_condition_order", 0)

        self.num_ctrlp = num_basis + self.init_cond_order + self.end_cond_order
        # Number of knots needed, given the B-spline degree and the number of
        # control points (num_basis + init_cond_order + end_cond_order)
        num_knots = self.degree_p + 1 + self.num_ctrlp
        num_knots_non_rep_inside_1 = num_knots - 2 * self.degree_p
        # Uniform knot vector, clamped by repeating the boundary knots
        knots_vec = torch.linspace(0, 1, num_knots_non_rep_inside_1,
                                   dtype=dtype, device=device)
        knots_prev = torch.zeros(self.degree_p, dtype=dtype, device=device)
        knots_post = torch.ones(self.degree_p, dtype=dtype, device=device)
        knots_vec = torch.cat([knots_prev, knots_vec, knots_post])
        self.register_buffer("knots_vec", knots_vec, persistent=False)

        tau = kwargs.get("tau")
        self.register_buffer('tau',
                             torch.tensor(tau, dtype=dtype, device=device),
                             persistent=False)

    @property
    def device(self):
        return self.knots_vec.device

    @property
    def dtype(self):
        return self.knots_vec.dtype

    def time2phase(self, times: torch.Tensor) -> torch.Tensor:
        """
        Scale time into a phase in the [0, 1] range
        :param times: time points
        :return: phase values
        """
        # Shape of times:
        # [*add_dim, num_times]

        tau = times.reshape(-1)[-1]
        self.tau.copy_(tau)
        phase = torch.clip(times / self.tau[..., None], 0, 1)
        return phase

    def basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Compute the B-spline basis evaluated at the given time points
        :param times: time points
        :return: basis values
        """
        # Shape of times:
        # [*add_dim, num_times]
        #
        # Shape of basis:
        # [*add_dim, num_times, num_ctrlp]

        phase = self.time2phase(times)

        basis = [self._basis_function(i, self.degree_p, self.knots_vec, phase)
                 for i in range(self.num_ctrlp)]
        basis = torch.stack(basis, dim=-1)

        return basis

    def _basis_function(self, i, k, knots, u, **kwargs):
        """
        Recursive construction of the B-spline basis via the Cox-de Boor
        recursion

        :param i: basis index
        :param k: degree
        :param u: evaluation time points
        :param knots: knot vector
        :return: vector of shape [num_eval_points]
        """

        if k == 0:
            num_ctrlp = kwargs.get("num_ctrlp", self.num_ctrlp)
            if i == num_ctrlp - 1:
                # In the original definition, each span is a left-closed,
                # right-open interval [v_i, v_{i+1}), which makes the value
                # at the right end always 0. That is undesired, so the last
                # basis needs special handling.
                b0 = torch.where((u >= knots[i]) & (u <= knots[i + 1]), 1, 0)
            else:
                b0 = torch.where((u >= knots[i]) & (u < knots[i + 1]), 1, 0)
            return torch.as_tensor(b0, dtype=self.dtype, device=self.device)
        else:
            denom1 = knots[i + k] - knots[i]
            term1 = 0.0 if denom1 == 0 else \
                (u - knots[i]) / denom1 * \
                self._basis_function(i, k - 1, knots, u, **kwargs)
            denom2 = knots[i + k + 1] - knots[i + 1]
            term2 = 0.0 if denom2 == 0 else \
                (knots[i + k + 1] - u) / denom2 * \
                self._basis_function(i + 1, k - 1, knots, u, **kwargs)
            return term1 + term2

    def vel_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Directly get the basis of the velocity B-spline
        :param times: time points
        :return: velocity basis values
        """

        phase = self.time2phase(times)

        # For a clamped uniform B-spline
        vel_knots_vec = self.knots_vec[1:-1]
        basis = \
            [self._basis_function(i, self.degree_p - 1, vel_knots_vec, phase,
                                  num_ctrlp=self.num_ctrlp - 1)
             for i in range(self.num_ctrlp - 1)]
        basis = torch.stack(basis, dim=-1)
        # goal_basis is optional and not set by __init__; default to False
        if getattr(self, "goal_basis", False):
            gb = torch.ones_like(phase, dtype=self.dtype,
                                 device=self.device)[..., None]
            basis = torch.cat([basis, gb], dim=-1)
        return basis

    def acc_basis(self, times: torch.Tensor) -> torch.Tensor:
        """
        Directly get the basis of the acceleration B-spline
        :param times: time points
        :return: acceleration basis values
        """

        phase = self.time2phase(times)

        acc_knots_vec = self.knots_vec[2:-2]

        basis = [
            self._basis_function(i, self.degree_p - 2, acc_knots_vec, phase,
                                 num_ctrlp=self.num_ctrlp - 2)
            for i in range(self.num_ctrlp - 2)]
        basis = torch.stack(basis, dim=-1)

        return basis

    def velocity_control_points(self, ctrl_pts: torch.Tensor):
        """
        Given the position control points (parameters), return the velocity
        control points of the velocity B-spline as a linear combination of
        the position control points.

        :param ctrl_pts: vector of position control points
        :return: velocity control points
        """
        # diff shape: [*add_dim, num_dof, num_ctrlp-1]
        diff = ctrl_pts[..., 1:] - ctrl_pts[..., :-1]
        # delta shape: [num_ctrlp-1]
        delta = \
            self.knots_vec[1 + self.degree_p: self.num_ctrlp + self.degree_p] \
            - self.knots_vec[1: self.num_ctrlp]
        diff = diff * (1 / delta)
        return diff * self.degree_p

    def acceleration_control_points(self, ctrl_pts: torch.Tensor):
        """
        Given the position control points (parameters), return the
        acceleration control points of the acceleration B-spline as a linear
        combination of the position control points.

        :param ctrl_pts: vector of position control points
        :return: acceleration control points
        """
        # shape: [*add_dim, num_dof, num_ctrlp-1]
        vel_ctrl_pts = self.velocity_control_points(ctrl_pts)
        # shape: [*add_dim, num_dof, num_ctrlp-2]
        diff = vel_ctrl_pts[..., 1:] - vel_ctrl_pts[..., :-1]
        # delta shape: [num_ctrlp-2]
        delta = \
            self.knots_vec[2 + self.degree_p: self.num_ctrlp + self.degree_p] \
            - self.knots_vec[2: self.num_ctrlp]
        diff = diff * (1 / delta)
        return diff * (self.degree_p - 1)

    def compute_init_params(self, init_pos, init_vel, **kwargs):
        """
        Given the initial condition, compute the corresponding first control
        points
        :param init_pos: initial position
        :param init_vel: initial velocity
        :param kwargs: unused keyword arguments
        :return: initial control points
        """

        # Shape of init_pos:
        # [*add_dim, num_dof]
        #
        # Shape of init_vel:
        # [*add_dim, num_dof]
        #
        # Return shape:
        # [*add_dim, num_dof, init_cond_order]

        if self.init_cond_order == 0:
            return None

        para_init_p = init_pos
        para_init = para_init_p[..., None]

        if self.init_cond_order == 2:
            para_init_v = \
                torch.einsum("...i,...->...i", init_vel, self.tau) * \
                (self.knots_vec[1 + self.degree_p] - self.knots_vec[1]) \
                / self.degree_p + para_init_p
            para_init = torch.cat([para_init, para_init_v[..., None]], dim=-1)

        return para_init

    def compute_end_params(self, end_pos, end_vel, **kwargs):
        """
        Given the end condition, compute the corresponding last control points
        :param end_pos: end position
        :param end_vel: end velocity
        :param kwargs: unused keyword arguments
        :return: end control points
        """
        # Shape of end_pos:
        # [*add_dim, num_dof]
        #
        # Shape of end_vel:
        # [*add_dim, num_dof]
        #
        # Return shape:
        # [*add_dim, num_dof, end_cond_order]

        if self.end_cond_order == 0:
            return None

        para_end_p = end_pos
        para_end = para_end_p[..., None]

        if self.end_cond_order == 2:
            # Mirror of the initial-condition formula: the knot span is
            # divided by the degree
            para_end_v = para_end_p - \
                torch.einsum("...i,...->...i", end_vel, self.tau) * \
                (self.knots_vec[self.num_ctrlp - 1 + self.degree_p] -
                 self.knots_vec[self.num_ctrlp - 1]) / self.degree_p
            para_end = torch.cat([para_end_v[..., None], para_end], dim=-1)

        return para_end

    def basis_multi_dofs(self,
                         times: torch.Tensor,
                         num_dof: int) -> torch.Tensor:
        """
        Interface to generate the basis function values at the given time
        points for multiple DoFs
        Args:
            times: times in Tensor
            num_dof: number of degrees of freedom
        Returns:
            basis_multi_dofs: multi-DoF basis functions in Tensor
        """
        # Shape of times
        # [*add_dim, num_times]
        #
        # Shape of basis_multi_dofs
        # [*add_dim, num_dof * num_times, num_dof * num_basis]

        # Extract additional dimensions
        add_dim = list(times.shape[:-1])

        # Get the single-DoF basis, shape: [*add_dim, num_times, num_ctrlp]
        basis_single_dof = self.basis(times)
        num_times = times.shape[-1]

        # Strip the control points fixed by the boundary conditions,
        # shape: [*add_dim, num_times, num_basis]
        basis_single_dof_ = basis_single_dof[
            ..., self.init_cond_order: self.num_ctrlp - self.end_cond_order]
        # Multiple DoFs, shape:
        # [*add_dim, num_dof * num_times, num_dof * num_basis]
        basis_multi_dofs = torch.zeros(*add_dim, num_dof * num_times,
                                       num_dof * self.num_basis,
                                       dtype=self.dtype,
                                       device=self.device)

        # Assemble the block-diagonal basis matrix
        for i in range(num_dof):
            row_indices = slice(i * num_times, (i + 1) * num_times)
            col_indices = slice(i * self.num_basis, (i + 1) * self.num_basis)
            basis_multi_dofs[..., row_indices, col_indices] = basis_single_dof_

        return basis_multi_dofs

    def show_basis(self, plot=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute basis function values for debugging purposes over the
        normalized phase range [0, 1]

        Returns: times and basis function values
        """
        times = torch.linspace(0, 1, steps=1000,
                               dtype=self.dtype, device=self.device)
        basis_values = self.basis(times)
        if plot:
            import matplotlib.pyplot as plt
            plt.figure()
            for i in range(basis_values.shape[-1]):
                plt.plot(times.cpu(), basis_values[:, i].cpu(),
                         label=f"basis_{i}")
            plt.grid()
            plt.legend()
            plt.axvline(x=0, linestyle='--', color='k', alpha=0.3)
            plt.axvline(x=1, linestyle='--', color='k', alpha=0.3)
            plt.show()
        return times, basis_values
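
For a quick sanity check of the basis generator, a minimal sketch (assuming basis_gn.py is importable as a module; the Cox-de Boor recursion in _basis_function yields a clamped basis that forms a partition of unity, so the basis values at every time point sum to 1):

import torch

from basis_gn import UniBSplineBasis  # assumes the module is on the path

# 10 cubic basis functions on a clamped uniform knot vector, tau = 1.0
bgen = UniBSplineBasis(num_basis=10, degree_p=3, tau=1.0)

times = torch.linspace(0, 1, 100)
basis = bgen.basis(times)  # shape: [100, 10]

# Partition of unity: the basis values at each time sum to 1
assert torch.allclose(basis.sum(dim=-1), torch.ones(100), atol=1e-5)

# Block-diagonal multi-DoF basis for 2 DoFs: shape [2 * 100, 2 * 10]
print(bgen.basis_multi_dofs(times, num_dof=2).shape)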
beast.py
ADDED
@@ -0,0 +1,280 @@
from functools import wraps

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
from addict import Dict
from transformers.processing_utils import ProcessorMixin

from .bspline_factory import SplineFactory
from .utils import (continuous_to_discrete, discrete_to_continuous,
                    normalize_tensor, denormalize_tensor, tensor_linspace)


def autocast_float32(fn):
    @wraps(fn)
    def wrapped(*args, **kwargs):
        with torch.cuda.amp.autocast(dtype=torch.float32):
            return fn(*args, **kwargs)
    return wrapped


class BeastTokenizer(torch.nn.Module, ProcessorMixin):
    """
    B-spline based tokenizer for trajectory encoding/decoding.

    Converts continuous trajectories to discrete tokens and vice versa using
    B-splines. Supports continuous and discrete representations of
    trajectories, and supports separate handling of continuous actions and
    discrete states (e.g., a binarized gripper state).
    """

    # Class constants
    DEFAULT_DT = 0.01  # 100 Hz sampling rate

    def __init__(self, num_dof=1, num_basis=10, seq_len=50, vocab_size=256,
                 degree_p=4, gripper_zero_order=False, gripper_dof=1,
                 init_cond_order=0, end_cond_order=0, enforce_init_pos=True,
                 device="cuda"):
        super().__init__()

        # Store core parameters
        self.device = device
        self.seq_length = seq_len
        self.vocab_size = vocab_size
        self.num_basis = num_basis
        self.enforce_init_pos = enforce_init_pos
        self.init_cond_order = init_cond_order
        self.end_cond_order = end_cond_order
        self.dt = self.DEFAULT_DT
        self.init_pos = None

        # Calculate the DOF distribution
        self.gripper_dof = gripper_dof if gripper_zero_order else 0
        self.joint_dof = num_dof - self.gripper_dof
        self.num_dof = self.joint_dof + self.gripper_dof

        # Initialize the spline components
        self.bsp = self._create_bsplines(self.joint_dof, degree_p)
        self.gripper_bsp = self._create_bsplines(self.gripper_dof, 0) \
            if gripper_zero_order else None

        # Set up the time grid and weight bounds,
        # working with normalized time [0, 1]
        self.times = tensor_linspace(0, 1.0, seq_len).to(device)
        self._initialize_weight_bounds()

        self.to(self.device)

    def _create_bsplines(self, num_dof, degree_p):
        """Create a motion primitive for joint trajectories."""
        config = Dict({
            'mp_type': 'uni_bspline',
            'device': self.device,
            'num_dof': num_dof,
            'tau': 1.0,
            'mp_args': {
                'num_basis': self.num_basis,
                'degree_p': degree_p,
                'init_condition_order': self.init_cond_order,
                'end_condition_order': self.end_cond_order,
                'dt': self.dt
            }
        })
        return SplineFactory.init_splines(**config)

    def _initialize_weight_bounds(self):
        """Initialize the weight bounds for normalization."""
        total_params = self.num_dof * self.num_basis
        self.register_buffer("w_min", -1.0 * torch.ones(total_params))
        self.register_buffer("w_max", 1.0 * torch.ones(total_params))

    def _get_repeated_times(self, batch_size):
        """Get the time tensor repeated for batch processing."""
        return einops.repeat(self.times, 't -> b t', b=batch_size)

    @autocast_float32
    def _learn_trajectory_params(self, times, trajs):
        """Learn B-spline parameters from trajectories."""
        # Learn the joint parameters
        joint_params = self.bsp.learn_mp_params_from_trajs(
            times, trajs[..., :self.joint_dof])

        # Learn the gripper parameters if applicable
        if self.gripper_bsp is not None:
            gripper_trajs = trajs[..., -self.gripper_dof:]
            gripper_params = self.gripper_bsp.learn_mp_params_from_trajs(
                times, gripper_trajs)
            joint_params['params'] = torch.cat(
                [joint_params['params'], gripper_params['params']], dim=-1)

        return joint_params

    @autocast_float32
    def _reconstruct_trajectory(self, params, times):
        """Reconstruct a trajectory from B-spline parameters."""
        # Reconstruct the joint trajectory
        joint_params = params[..., :self.joint_dof * self.num_basis]
        self.bsp.update_inputs(times=times, params=joint_params)
        position = self.bsp.get_traj_pos()

        # Reconstruct the gripper trajectory if applicable
        if self.gripper_bsp is not None:
            gripper_params = params[..., -self.gripper_dof * self.num_basis:]
            self.gripper_bsp.update_inputs(times=times, params=gripper_params)
            gripper_pos = self.gripper_bsp.get_traj_pos()
            position = torch.cat([position, gripper_pos], dim=-1)

        return position

    def _apply_initial_position_constraint(self, params, init_pos):
        """Apply the initial position constraint to the parameters."""
        # Only enforce when configured and an initial position is given
        if not self.enforce_init_pos or init_pos is None:
            return params

        # Reshape to access the individual basis functions
        reshaped_params = einops.rearrange(
            params, "b (d t) -> b t d", t=self.num_basis, d=self.num_dof)

        # Set the initial position for the joint DOFs
        reshaped_params[:, 0, :self.joint_dof] = init_pos[:, :self.joint_dof]

        return einops.rearrange(reshaped_params, "b t d -> b (d t)")

    @autocast_float32
    def compute_weights(self, demos):
        """Compute B-spline weights from demonstration trajectories."""
        times = self._get_repeated_times(demos.shape[0])
        weights = self.bsp.learn_mp_params_from_trajs(times, demos)['params']
        return weights

    def update_weights_bounds_per_batch(self, weights):
        """Update the weight bounds based on batch statistics."""
        weights = weights.reshape(-1, self.num_dof * self.num_basis)
        batch_min = weights.min(dim=0)[0]
        batch_max = weights.max(dim=0)[0]

        # Update the bounds with a small tolerance
        tolerance = 1e-4
        smaller_mask = batch_min < (self.w_min - tolerance)
        larger_mask = batch_max > (self.w_max + tolerance)

        if torch.any(smaller_mask):
            self.w_min[smaller_mask] = batch_min[smaller_mask]
        if torch.any(larger_mask):
            self.w_max[larger_mask] = batch_max[larger_mask]

    def update_times(self, times):
        """Update the time grid."""
        self.times = times

    @torch.no_grad()
    @autocast_float32
    def encode_discrete(self, trajs, update_bounds=False, init_p=None):
        """Encode trajectories to discrete tokens."""
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)

        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])

        # Clamp the parameters to the bounds
        params = torch.clamp(params_dict['params'],
                             min=self.w_min, max=self.w_max)

        # Convert to discrete tokens
        tokens = continuous_to_discrete(params, min_val=self.w_min,
                                        max_val=self.w_max,
                                        num_bins=self.vocab_size)
        tokens = einops.rearrange(tokens, 'b (d t) -> b (t d)',
                                  t=self.num_basis, d=self.num_dof)

        return tokens

    @torch.no_grad()
    @autocast_float32
    def decode_discrete(self, tokens, times=None, init_pos=None):
        """Decode discrete tokens to trajectories."""
        # Reshape the tokens and convert them to continuous parameters
        tokens = einops.rearrange(tokens, 'b (t d) -> b (d t)',
                                  t=self.num_basis, d=self.num_dof)
        params = discrete_to_continuous(tokens, min_val=self.w_min,
                                        max_val=self.w_max,
                                        num_bins=self.vocab_size)

        if times is None:
            times = self._get_repeated_times(params.shape[0])

        # Apply the initial position constraint if specified
        params = self._apply_initial_position_constraint(params, init_pos)

        return self._reconstruct_trajectory(params, times)

    @torch.no_grad()
    @autocast_float32
    def encode_continuous(self, trajs, update_bounds=False):
        """Encode trajectories to continuous tokens (normalized parameters)."""
        times = self._get_repeated_times(trajs.shape[0])
        params_dict = self._learn_trajectory_params(times, trajs)

        if update_bounds:
            self.update_weights_bounds_per_batch(params_dict['params'])

        # Normalize the parameters
        tokens = normalize_tensor(params_dict['params'],
                                  w_min=self.w_min, w_max=self.w_max)

        return tokens

    @torch.no_grad()
    @autocast_float32
    def decode_continuous(self, params, times=None, init_pos=None):
        """Decode continuous tokens (normalized parameters) to trajectories."""
        # Denormalize the parameters
        params = denormalize_tensor(params, w_min=self.w_min, w_max=self.w_max)

        if times is None:
            times = self._get_repeated_times(params.shape[0])

        # Apply the initial position constraint if specified
        params = self._apply_initial_position_constraint(params, init_pos)

        return self._reconstruct_trajectory(params, times)

    @autocast_float32
    def compute_reconstruction_error(self, raw_traj):
        """Compute the reconstruction error for a trajectory."""
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(-1)

        tokens = self.encode_discrete(raw_traj)
        reconstructed = self.decode_discrete(tokens)
        error = torch.mean((raw_traj - reconstructed) ** 2)

        return error

    def _plot_trajectory_comparison(self, original, reconstructed,
                                    title_prefix=""):
        """Helper method to plot a trajectory comparison."""
        original = original.detach().cpu().numpy()
        reconstructed = reconstructed.detach().cpu().numpy()
        x_vals = np.linspace(0, 1.0, original.shape[1])

        batch_size, time_steps, dof = original.shape

        for sample_idx in range(batch_size):
            fig, axes = plt.subplots(dof, 1, figsize=(8, 2 * dof), sharex=True)
            if dof == 1:
                axes = [axes]  # Handle the single-DOF case

            for i in range(dof):
                axes[i].plot(x_vals, reconstructed[sample_idx, :, i],
                             marker='o', label='Reconstructed',
                             linestyle='-', color='b')
                axes[i].plot(x_vals, original[sample_idx, :, i],
                             marker='*', label='Ground Truth',
                             linestyle='--', color='r')
                axes[i].set_ylabel(f"DOF {i + 1}")
                axes[i].grid(True)
                axes[i].legend(loc="best")

            axes[-1].set_xlabel("Time (s)")
            plt.suptitle(f"{title_prefix}Trajectory Comparison - "
                         f"Sample {sample_idx}")
            plt.tight_layout(rect=[0, 0, 1, 0.96])
            plt.show()

    def visualize_reconstruction_error_discrete(self, raw_traj):
        """Visualize the reconstruction error for discrete encoding."""
        tokens = self.encode_discrete(raw_traj, update_bounds=True)
        reconstructed = self.decode_discrete(tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed, "Discrete ")

    def visualize_reconstruction_error_continuous(self, raw_traj):
        """Visualize the reconstruction error for continuous encoding."""
        raw_traj = raw_traj.to(torch.float32)
        if len(raw_traj.shape) == 2:
            raw_traj = raw_traj.unsqueeze(0)

        continuous_tokens = self.encode_continuous(raw_traj,
                                                   update_bounds=True)
        reconstructed = self.decode_continuous(continuous_tokens)
        self._plot_trajectory_comparison(raw_traj, reconstructed,
                                         "Continuous ")
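
For orientation, a minimal encode/decode round trip with the tokenizer, sketched under two assumptions: the package-relative imports resolve (i.e. the files are installed as a package), and on a CPU-only build torch may warn and disable autocast, since the decorators target CUDA:

import torch

from beast import BeastTokenizer  # assumes the package imports resolve

tokenizer = BeastTokenizer(num_dof=7, num_basis=10, seq_len=50,
                           vocab_size=256, degree_p=4, device="cpu")

# Batch of 4 demo trajectories: [batch, time, dof]
demos = torch.rand(4, 50, 7)

# Fit the per-dimension token bounds on this batch, then round-trip
tokens = tokenizer.encode_discrete(demos, update_bounds=True)  # [4, 70], long
recon = tokenizer.decode_discrete(tokens)                      # [4, 50, 7]
print(torch.mean((demos - recon) ** 2))  # fitting + quantization error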
bspline_factory.py
ADDED
@@ -0,0 +1,25 @@
import torch

from .basis_gn import UniBSplineBasis
from .uni_bspline import UniformBSpline


class SplineFactory:

    @staticmethod
    def init_splines(mp_type: str,
                     mp_args: dict,
                     num_dof: int = 1,
                     tau: float = 1,
                     dtype: torch.dtype = torch.float32,
                     device: torch.device = "cpu"):

        if mp_type == "uni_bspline":
            basis_gn = UniBSplineBasis(dtype=dtype, device=device, tau=tau,
                                       **mp_args)
            mp = UniformBSpline(basis_gn=basis_gn, num_dof=num_dof,
                                dtype=dtype, device=device, **mp_args)
        else:
            raise NotImplementedError

        return mp
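
A short usage sketch of the factory, mirroring how beast.py builds its splines (the mp_args keys are forwarded to UniBSplineBasis as keyword arguments; it assumes the modules are importable):

import torch

from bspline_factory import SplineFactory  # assumes imports resolve

mp = SplineFactory.init_splines(
    mp_type="uni_bspline",
    mp_args={"num_basis": 10, "degree_p": 3,
             "init_condition_order": 0, "end_condition_order": 0},
    num_dof=2, tau=1.0, device="cpu")

times = torch.linspace(0, 1, 50)[None]  # [*add_dim, num_times] = [1, 50]
params = torch.rand(1, 2 * 10)          # [*add_dim, num_dof * num_basis]
pos = mp.get_traj_pos(times=times, params=params)
print(pos.shape)                        # torch.Size([1, 50, 2])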
uni_bspline.py
ADDED
@@ -0,0 +1,462 @@
from typing import List, Optional, Union
import logging

import numpy as np
import torch

from .basis_gn import UniBSplineBasis


class UniformBSpline(torch.nn.Module):

    def __init__(self,
                 basis_gn: UniBSplineBasis,
                 num_dof: int,
                 weights_scale: float = 1.,
                 dtype: torch.dtype = torch.float32,
                 device: torch.device = 'cpu',
                 **kwargs):
        super().__init__()

        # Additional batch dimensions
        self.add_dim = list()

        self.basis_gn = basis_gn
        self.num_dof = num_dof

        # Scaling of the weights
        weights_scale = \
            torch.tensor(weights_scale, dtype=self.dtype, device=self.device)
        assert weights_scale.ndim <= 1, \
            "weights_scale should be a float or a 1-dim vector"
        self.register_buffer("weights_scale", weights_scale, persistent=False)

        # Value caches
        # Compute values at these time points
        self.times = None

        # Learnable parameters
        self.params = None

        # Initial conditions
        self.init_pos = None
        self.init_vel = None

        # Runtime computation results; reset every time the inputs are reset
        self.pos = None
        self.vel = None

        # End conditions
        self.end_pos = None
        self.end_vel = None

        self.params_init = None
        self.params_end = None

    @property
    def device(self):
        return self.basis_gn.device

    @property
    def dtype(self):
        return self.basis_gn.dtype

    @property
    def tau(self):
        return self.basis_gn.tau

    @property
    def num_basis(self):
        return self.basis_gn.num_basis

    @property
    def num_params(self):
        return self.basis_gn.num_basis * self.num_dof

    def clear_computation_result(self):
        """
        Clear the runtime computation results

        Returns:
            None
        """

        self.pos = None
        self.vel = None
        # also reset tau?

    def set_add_dim(self, add_dim: Union[list, torch.Size]):
        """
        Set additional batch dimensions
        Args:
            add_dim: additional batch dimensions

        Returns: None

        """
        self.add_dim = add_dim
        self.clear_computation_result()

    def set_times(self, times: Union[torch.Tensor, np.ndarray]):
        """
        Set the time points
        Args:
            times: time points

        Returns:
            None
        """

        # Shape of times
        # [*add_dim, num_times]

        self.times = torch.as_tensor(times, dtype=self.dtype,
                                     device=self.device)
        tau = self.times.reshape(-1)[-1]
        self.basis_gn.tau.copy_(tau)
        self.clear_computation_result()

    def set_duration(self, duration: Optional[float], dt: float):
        """
        Set the time points from a desired duration and time step
        Args:
            duration: desired duration of the trajectory
            dt: control time step
        Returns:
            None
        """

        # Shape of times
        # [*add_dim, num_times]
        times = torch.linspace(0, duration, round(duration / dt) + 1,
                               dtype=self.dtype, device=self.device)
        times = add_expand_dim(times, list(range(len(self.add_dim))),
                               self.add_dim)
        self.set_times(times)

    def set_params(self,
                   params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Set the MP params
        Args:
            params: parameters

        Returns: unused parameters

        """
        # Shape of params
        # [*add_dim, num_params]

        params = torch.as_tensor(params, dtype=self.dtype, device=self.device)

        # Check the number of params
        assert params.shape[-1] == self.num_params

        # Set the additional batch size
        self.set_add_dim(list(params.shape[:-1]))

        self.params = params[..., :self.num_params]
        self.clear_computation_result()
        return params[..., self.num_params:]

    def update_inputs(self, times=None, params=None,
                      init_pos=None, init_vel=None, **kwargs):

        if params is not None:
            self.set_params(params)
        if times is not None:
            self.set_times(times)
        if init_pos is not None:
            self.set_initial_conditions(init_pos, init_vel, **kwargs)

        end_pos = kwargs.get('end_pos', None)
        end_vel = kwargs.get('end_vel', None)
        if any(cond is not None for cond in [end_pos, end_vel]):
            self.set_end_conditions(end_pos, end_vel)

    def set_initial_conditions(self,
                               init_pos: Union[torch.Tensor, np.ndarray],
                               init_vel: Union[torch.Tensor, np.ndarray],
                               **kwargs):

        self.init_pos = torch.as_tensor(init_pos, dtype=self.dtype,
                                        device=self.device)
        self.init_vel = torch.as_tensor(init_vel, dtype=self.dtype,
                                        device=self.device) \
            if init_vel is not None else None
        self.clear_computation_result()

        self.params_init = self.basis_gn.compute_init_params(self.init_pos,
                                                             self.init_vel)
        if self.params_init is not None:
            self.params_init /= self.weights_scale

    def set_end_conditions(self, end_pos: Union[torch.Tensor, np.ndarray],
                           end_vel: Union[torch.Tensor, np.ndarray],
                           **kwargs):
        self.end_pos = \
            torch.as_tensor(end_pos, device=self.device, dtype=self.dtype) \
            if end_pos is not None else None
        self.end_vel = \
            torch.as_tensor(end_vel, device=self.device, dtype=self.dtype) \
            if end_vel is not None else None

        self.params_end = self.basis_gn.compute_end_params(self.end_pos,
                                                           self.end_vel)
        if self.params_end is not None:
            self.params_end /= self.weights_scale

    def get_traj_pos(self, times=None, params=None,
                     init_pos=None, init_vel=None, flat_shape=False,
                     **kwargs):

        self.update_inputs(times, params, init_pos, init_vel, **kwargs)

        if self.pos is not None:
            pos = self.pos
        else:
            assert self.params is not None

            # Reshape params
            # [*add_dim, num_dof * num_basis] -> [*add_dim, num_dof, num_basis]
            params = self.params.reshape(*self.add_dim, self.num_dof, -1)
            # Extend params with the possible init and end conditions,
            # shape: [*add_dim, num_dof, num_ctrlp]
            if self.params_init is not None:
                params = torch.cat((self.params_init, params), dim=-1)
            if self.params_end is not None:
                params = torch.cat((params, self.params_end), dim=-1)

            # Get basis
            # Shape: [*add_dim, num_times, num_ctrlp]
            basis_single_dof = \
                self.basis_gn.basis(self.times) * self.weights_scale

            # Einsum shape: [*add_dim, num_times, num_ctrlp],
            #               [*add_dim, num_dof, num_ctrlp]
            #               -> [*add_dim, num_times, num_dof]
            pos = torch.einsum('...ik,...jk->...ij', basis_single_dof, params)

            self.pos = pos

        if flat_shape:
            # Switch axes to [*add_dim, num_dof, num_times]
            pos = torch.einsum('...ji->...ij', pos)

            # Reshape to [*add_dim, num_dof * num_times]
            pos = pos.reshape(*self.add_dim, -1)

        return pos

    def get_traj_vel(self, times=None, params=None,
                     init_pos=None, init_vel=None, flat_shape=False,
                     **kwargs):

        self.update_inputs(times, params, init_pos, init_vel, **kwargs)

        if self.vel is not None:
            vel = self.vel
        else:
            assert self.params is not None

            # Reshape params
            # [*add_dim, num_dof * num_basis] -> [*add_dim, num_dof, num_basis]
            params = self.params.reshape(*self.add_dim, self.num_dof, -1)
            # Extend params with the possible init and end conditions,
            # shape: [*add_dim, num_dof, num_ctrlp]
            if self.params_init is not None:
                params = torch.cat((self.params_init, params), dim=-1)
            if self.params_end is not None:
                params = torch.cat((params, self.params_end), dim=-1)

            # Velocity control points, shape: [*add_dim, num_dof, num_ctrlp-1]
            vel_ctrlp = self.basis_gn.velocity_control_points(params)
            vel_ctrlp = torch.einsum("...ij,...->...ij", vel_ctrlp,
                                     1 / self.tau)

            # vel_basis shape: [*add_dim, num_times, num_ctrlp-1]
            vel_basis = self.basis_gn.vel_basis(self.times) * \
                self.weights_scale

            # Einsum shape: [*add_dim, num_times, num_ctrlp-1],
            #               [*add_dim, num_dof, num_ctrlp-1]
            #               -> [*add_dim, num_times, num_dof]
            vel = torch.einsum('...ik,...jk->...ij', vel_basis, vel_ctrlp)

            self.vel = vel

        if flat_shape:
            # Switch axes to [*add_dim, num_dof, num_times]
            vel = torch.einsum('...ji->...ij', vel)

            # Reshape to [*add_dim, num_dof * num_times]
            vel = vel.reshape(*self.add_dim, -1)

        return vel

    def learn_mp_params_from_trajs(self, times: torch.Tensor,
                                   trajs: torch.Tensor, reg=1e-5, **kwargs):

        # Only works for learn_tau=False, learn_delay=False and delay=0
        # (otherwise the initial conditions must be given explicitly)

        # Shape of times
        # [*add_dim, num_times]
        #
        # Shape of trajs:
        # [*add_dim, num_times, num_dof]
        #
        # Shape of params:
        # [*add_dim, num_dof * num_basis]

        assert trajs.shape[:-1] == times.shape
        assert trajs.shape[-1] == self.num_dof

        times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
        trajs = torch.as_tensor(trajs, dtype=self.dtype, device=self.device)

        # Set up the batch dimensions and times
        self.set_add_dim(list(trajs.shape[:-2]))
        self.set_times(times)
        dummy_params = torch.zeros(*self.add_dim, self.num_dof,
                                   self.num_basis,
                                   device=self.device, dtype=self.dtype)

        # Get the initial conditions
        if self.basis_gn.init_cond_order != 0:
            if any(key in kwargs.keys()
                   for key in ["init_pos", "init_vel"]):
                logging.warning("using the given initial conditions")
                init_pos = kwargs.get("init_pos")
                init_vel = kwargs.get("init_vel")
            else:
                init_pos = trajs[..., 0, :]
                dt = (times[..., 1] - times[..., 0])
                init_vel = torch.einsum("...i,...->...i",
                                        torch.diff(trajs, dim=-2)[..., 0, :],
                                        1 / dt)
            self.set_initial_conditions(init_pos, init_vel)
            if self.params_init is not None:
                dummy_params = torch.cat([self.params_init, dummy_params],
                                         dim=-1)

        if self.basis_gn.end_cond_order != 0:
            if any(key in kwargs.keys()
                   for key in ["end_pos", "end_vel"]):
                logging.warning("using the given end conditions")
                end_pos = kwargs.get("end_pos")
                end_vel = kwargs.get("end_vel")
            else:
                end_pos = trajs[..., -1, :]
                dt = (times[..., 1] - times[..., 0])
                end_vel = torch.einsum("...i,...->...i",
                                       torch.diff(trajs, dim=-2)[..., -1, :],
                                       1 / dt)
            self.set_end_conditions(end_pos, end_vel)
            if self.params_end is not None:
                dummy_params = torch.cat([dummy_params, self.params_end],
                                         dim=-1)

        # Deterministic position contribution of the boundary control points
        # Einsum shape: [*add_dim, num_times, num_ctrlp],
        #               [*add_dim, num_dof, num_ctrlp]
        #               -> [*add_dim, num_times, num_dof]
        basis_single_dof = self.basis_gn.basis(times) * self.weights_scale
        pos_det = torch.einsum('...ik,...jk->...ij', basis_single_dof,
                               dummy_params)
        # Switch axes to [*add_dim, num_dof, num_times]
        pos_det = torch.einsum('...ij->...ji', pos_det)
        pos_det = pos_det.reshape(*self.add_dim, -1)

        basis_multi_dofs = self.basis_gn.basis_multi_dofs(
            self.times, self.num_dof) * self.weights_scale
        # Solve Aw = B -> w = A^{-1} B
        # Einsum shape: [*add_dim, num_dof * num_times, num_dof * num_basis]
        #               [*add_dim, num_dof * num_times, num_dof * num_basis]
        #               -> [*add_dim, num_dof * num_basis, num_dof * num_basis]
        A = torch.einsum('...ki,...kj->...ij', basis_multi_dofs,
                         basis_multi_dofs)
        A += torch.eye(self.num_params,
                       dtype=self.dtype,
                       device=self.device) * reg

        # Swap axes and reshape: [*add_dim, num_times, num_dof]
        # -> [*add_dim, num_dof, num_times]
        trajs = torch.einsum("...ij->...ji", trajs)
        # Reshape [*add_dim, num_dof, num_times]
        # -> [*add_dim, num_dof * num_times]
        trajs = trajs.reshape([*self.add_dim, -1])

        # Position minus the initial/end condition terms
        pos_w = trajs - pos_det

        # Einsum shape: [*add_dim, num_dof * num_times, num_dof * num_basis]
        #               [*add_dim, num_dof * num_times]
        #               -> [*add_dim, num_dof * num_basis]
        B = torch.einsum('...ki,...k->...i', basis_multi_dofs, pos_w)

        # Shape of weights: [*add_dim, num_dof * num_basis]
        params = torch.linalg.solve(A, B)

        self.set_params(params)

        return {"params": params,
                "init_pos": self.init_pos,
                "init_vel": self.init_vel,
                "end_pos": self.end_pos,
                "end_vel": self.end_vel,
                }


def add_expand_dim(data: Union[torch.Tensor, np.ndarray],
                   add_dim_indices: List[int],
                   add_dim_sizes: List[int]) -> Union[torch.Tensor,
                                                      np.ndarray]:
    """
    Add additional dimensions to a tensor and expand accordingly
    Args:
        data: tensor to operate on, torch.Tensor or numpy.ndarray
        add_dim_indices: indices of the added dimensions in the result tensor
        add_dim_sizes: expanded sizes of the additional dimensions

    Returns:
        result: tensor after adding and expanding dimensions
    """
    num_data_dim = data.ndim
    num_dim_to_add = len(add_dim_indices)

    # Map negative indices to their positive counterparts
    add_dim_reverse_indices = [num_data_dim + num_dim_to_add + idx
                               for idx in add_dim_indices]

    index = []
    expand_sizes = []
    add_dim_index = 0
    for dim in range(num_data_dim + num_dim_to_add):
        if dim in add_dim_indices or dim in add_dim_reverse_indices:
            # Insert a new axis here and expand it to the requested size
            index.append(None)
            expand_sizes.append(add_dim_sizes[add_dim_index])
            add_dim_index += 1
        else:
            # Keep the existing axis (-1 keeps a torch dim, 1 keeps a np dim)
            index.append(slice(None))
            expand_sizes.append(-1 if isinstance(data, torch.Tensor) else 1)

    if isinstance(data, torch.Tensor):
        return data[tuple(index)].expand(*expand_sizes)
    elif isinstance(data, np.ndarray):
        return np.tile(data[tuple(index)], expand_sizes)
    else:
        raise NotImplementedError
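
As a sanity check of the regularized least-squares fit in learn_mp_params_from_trajs (which solves the normal equations A w = B above), a minimal sketch, assuming the modules are importable:

import torch

from bspline_factory import SplineFactory  # assumes imports resolve

mp = SplineFactory.init_splines(
    mp_type="uni_bspline",
    mp_args={"num_basis": 12, "degree_p": 3,
             "init_condition_order": 0, "end_condition_order": 0},
    num_dof=1, tau=1.0, device="cpu")

times = torch.linspace(0, 1, 100)[None]             # [1, 100]
trajs = torch.sin(2 * torch.pi * times)[..., None]  # [1, 100, 1]

result = mp.learn_mp_params_from_trajs(times, trajs, reg=1e-5)
recon = mp.get_traj_pos(times=times, params=result["params"])
print(torch.mean((trajs - recon) ** 2))  # should be close to zero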
utils.py
ADDED
@@ -0,0 +1,166 @@
from typing import Union

import torch


def continuous_to_discrete(tensor, min_val=None, max_val=None, num_bins=256):
    """
    Convert a continuous PyTorch tensor to discrete tokens in the range
    [0, num_bins - 1].

    Args:
        tensor (torch.Tensor): Input tensor with continuous values.
        min_val (float, optional): Minimum value for normalization. If None,
            use tensor.min().
        max_val (float, optional): Maximum value for normalization. If None,
            use tensor.max().
        num_bins (int): Number of quantization bins.

    Returns:
        torch.Tensor: Discretized tensor with values in [0, num_bins - 1].
    """

    if min_val is None:
        min_val = tensor.min()
    if max_val is None:
        max_val = tensor.max()

    # Ensure no out-of-bound values, then normalize the tensor to [0, 1]
    assert torch.all(tensor >= min_val - 1e-3), \
        "Input tensor has values below min_val"
    assert torch.all(tensor <= max_val + 1e-3), \
        "Input tensor has values above max_val"
    normalized_tensor = (tensor - min_val) / (max_val - min_val)
    normalized_tensor = torch.clamp(normalized_tensor, 0, 1)

    # Scale to [0, num_bins - 1] and quantize to integers
    discrete_tensor = torch.round(
        normalized_tensor * (num_bins - 1)).to(torch.long)
    return discrete_tensor


def discrete_to_continuous(discrete_tensor, min_val=0, max_val=1,
                           num_bins=256):
    """
    Convert a discrete PyTorch tensor with values in [0, num_bins - 1] back
    to continuous values in the range [min_val, max_val].

    Args:
        discrete_tensor (torch.Tensor): Input tensor with discrete values.
        min_val (float): Minimum value of the original continuous range.
        max_val (float): Maximum value of the original continuous range.
        num_bins (int): Number of quantization bins.

    Returns:
        torch.Tensor: Continuous tensor with values in [min_val, max_val].
    """
    # Map the discrete tokens to [0, 1]
    normalized_tensor = discrete_tensor.float() / (num_bins - 1)

    # Map the normalized values to [min_val, max_val]
    continuous_tensor = normalized_tensor * (max_val - min_val) + min_val

    # Ensure no out-of-bound values
    continuous_tensor = torch.clamp(continuous_tensor, min_val, max_val)
    return continuous_tensor


def normalize_tensor(tensor, w_min, w_max, norm_min=-1.0, norm_max=1.0):
    """
    Normalize a tensor from its original range [w_min, w_max] to a new range
    [norm_min, norm_max].

    Args:
        tensor (torch.Tensor): Input tensor to be normalized
        w_min (float): Minimum value bound of the original tensor
        w_max (float): Maximum value bound of the original tensor
        norm_min (float, optional): Minimum of the normalized range.
            Defaults to -1.0.
        norm_max (float, optional): Maximum of the normalized range.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Normalized tensor with values in [norm_min, norm_max]
    """

    # Clip the input tensor to be within [w_min, w_max]
    clipped_tensor = torch.clamp(tensor, w_min, w_max)

    # Normalize to the [0, 1] range first
    normalized = (clipped_tensor - w_min) / (w_max - w_min)

    # Scale to the desired [norm_min, norm_max] range
    normalized = normalized * (norm_max - norm_min) + norm_min

    return normalized


def denormalize_tensor(normalized_tensor, w_min, w_max, norm_min=-1.0,
                       norm_max=1.0):
    """
    Denormalize a tensor from the normalized range [norm_min, norm_max] back
    to the original range [w_min, w_max].

    Args:
        normalized_tensor (torch.Tensor): Normalized input tensor
        w_min (float): Minimum value bound of the original range
        w_max (float): Maximum value bound of the original range
        norm_min (float, optional): Minimum of the normalized range.
            Defaults to -1.0.
        norm_max (float, optional): Maximum of the normalized range.
            Defaults to 1.0.

    Returns:
        torch.Tensor: Denormalized tensor with values in [w_min, w_max]
    """

    # Clip the normalized tensor to be within [norm_min, norm_max]
    clipped_tensor = torch.clamp(normalized_tensor, norm_min, norm_max)

    # Scale from [norm_min, norm_max] to [0, 1] first
    denormalized = (clipped_tensor - norm_min) / (norm_max - norm_min)

    # Scale to the original [w_min, w_max] range
    denormalized = denormalized * (w_max - w_min) + w_min

    return denormalized


def tensor_linspace(start: Union[float, int, torch.Tensor],
                    end: Union[float, int, torch.Tensor],
                    steps: int) -> torch.Tensor:
    """
    Vectorized version of torch.linspace.
    Modified from:
    https://github.com/zhaobozb/layout2im/blob/master/models/bilinear.py#L246

    Args:
        start: start value, scalar or tensor
        end: end value, scalar or tensor
        steps: number of steps

    Returns:
        linspace tensor
    """
    # Shape of start:
    # [*add_dim, dim_data] or a scalar
    #
    # Shape of end:
    # [*add_dim, dim_data] or a scalar
    #
    # Shape of out:
    # [*add_dim, steps, dim_data]
    #
    # out is a tensor of shape start.size() + (steps,) before the final
    # transposition, such that out.select(-1, 0) == start,
    # out.select(-1, -1) == end, and the other elements linearly interpolate
    # between start and end.

    if isinstance(start, torch.Tensor) and not isinstance(end, torch.Tensor):
        end += torch.zeros_like(start)
    elif not isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
        start += torch.zeros_like(end)
    elif isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
        assert start.size() == end.size()
    else:
        return torch.linspace(start, end, steps)

    view_size = start.size() + (1,)
    w_size = (1,) * start.dim() + (steps,)
    out_size = start.size() + (steps,)

    start_w = torch.linspace(1, 0, steps=steps).to(start)
    start_w = start_w.view(w_size).expand(out_size)
    end_w = torch.linspace(0, 1, steps=steps).to(start)
    end_w = end_w.view(w_size).expand(out_size)

    start = start.contiguous().view(view_size).expand(out_size)
    end = end.contiguous().view(view_size).expand(out_size)

    out = start_w * start + end_w * end
    out = torch.einsum('...ji->...ij', out)
    return out
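
For illustration, a round trip through the quantization and normalization helpers; uniform quantization bounds the reconstruction error by half a bin width, i.e. (max_val - min_val) / (2 * (num_bins - 1)):

import torch

from utils import (continuous_to_discrete, discrete_to_continuous,
                   normalize_tensor, denormalize_tensor)

x = torch.rand(4, 70) * 2 - 1  # values in [-1, 1]

tokens = continuous_to_discrete(x, min_val=-1.0, max_val=1.0, num_bins=256)
x_hat = discrete_to_continuous(tokens, min_val=-1.0, max_val=1.0,
                               num_bins=256)

# The quantization error is at most half a bin width
assert (x - x_hat).abs().max() <= 2.0 / (2 * 255) + 1e-6

# The normalization round trip is exact up to float precision
z = normalize_tensor(x, w_min=-1.0, w_max=1.0)
assert torch.allclose(denormalize_tensor(z, w_min=-1.0, w_max=1.0), x,
                      atol=1e-6)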