zhouhongyi commited on
Commit
475ce7c
·
1 Parent(s): 6894a19

intial commint

Browse files
Files changed (5) hide show
  1. basis_gn.py +349 -0
  2. beast.py +280 -0
  3. bspline_factory.py +25 -0
  4. uni_bspline.py +462 -0
  5. utils.py +166 -0
basis_gn.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @brief: Basis generators in PyTorch
3
+ """
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+
9
+ class UniBSplineBasis(torch.nn.Module):
10
+
11
+ def __init__(self,
12
+ num_basis: int = 10,
13
+ degree_p: int = 3,
14
+ dtype: torch.dtype = torch.float32,
15
+ device: torch.device = 'cpu',
16
+ **kwargs):
17
+ """
18
+ Constructor for basis class
19
+ Args:
20
+ num_basis: number of basis functions
21
+ dtype: torch data type
22
+ device: torch device to run on
23
+ """
24
+ super().__init__()
25
+
26
+ # Internal number of basis
27
+ self.num_basis = num_basis
28
+
29
+ self.degree_p = degree_p
30
+ self.init_cond_order = kwargs.get("init_condition_order", 0)
31
+ self.end_cond_order = kwargs.get("end_condition_order", 0)
32
+
33
+ self.num_ctrlp = num_basis + self.init_cond_order + self.end_cond_order
34
+ # number of knots needed, with respect to B-sp degree and number of
35
+ # control points ( num_basis + init_cond_order+end_cond_order)
36
+ num_knots = self.degree_p + 1 + self.num_ctrlp
37
+ num_knots_non_rep_inside_1 = num_knots - 2 * self.degree_p
38
+ # uniform knots vector
39
+ knots_vec = torch.linspace(0, 1, num_knots_non_rep_inside_1,
40
+ dtype=dtype, device=device)
41
+ knots_prev = torch.zeros(self.degree_p, dtype=dtype, device=device)
42
+ knots_pro = torch.ones(self.degree_p, dtype=dtype, device=device)
43
+ knots_vec = torch.cat([knots_prev, knots_vec, knots_pro])
44
+ self.register_buffer("knots_vec", knots_vec, persistent=False)
45
+
46
+ tau = kwargs.get("tau")
47
+ self.register_buffer('tau',
48
+ torch.tensor(tau, dtype=dtype, device=device),
49
+ persistent=False)
50
+
51
+ @property
52
+ def device(self):
53
+ return self.knots_vec.device
54
+
55
+ @property
56
+ def dtype(self):
57
+ return self.knots_vec.dtype
58
+
59
+ def time2phase(self, times: torch.Tensor) -> torch.Tensor:
60
+ """
61
+ scaling time into [0,1] range phase
62
+ :param times:
63
+ :return:
64
+ """
65
+ # Shape of times:
66
+ # [*add_dim, num_times]
67
+
68
+ # tau = times[..., -1]
69
+ tau = times.reshape(-1)[-1]
70
+ self.tau.copy_(tau)
71
+ phase = torch.clip(times / self.tau[..., None], 0, 1)
72
+ return phase
73
+
74
+ def basis(self, times: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ compute evaluated b-spline basis at given time points
77
+ :param times:
78
+ :return:
79
+ """
80
+
81
+ # Shape of times:
82
+ # [*add_dim, num_times]
83
+ #
84
+ # Shape of basis:
85
+ # [*add_dim, num_times, num_ctrlp]
86
+
87
+ # phase = self.phase_generator.phase(times)
88
+ phase = self.time2phase(times)
89
+
90
+ basis = [self._basis_function(i, self.degree_p, self.knots_vec, phase)
91
+ for i in range(self.num_ctrlp)]
92
+ basis = torch.stack(basis, dim=-1)
93
+
94
+ return basis
95
+
96
+ def _basis_function(self, i, k, knots, u, **kwargs):
97
+ """
98
+ recursive construct of B-spline basis using de Boor's algorithm
99
+
100
+ :param i: basis index
101
+ :param k: degree
102
+ :param u: evaluate time point
103
+ :param knots: knots vector
104
+ :return: vector of shape [num_eval_points]
105
+ """
106
+
107
+ if k == 0:
108
+ num_ctrlp = kwargs.get("num_ctrlp", self.num_ctrlp)
109
+ if i == num_ctrlp - 1:
110
+ # with regard to original definition, each span is defined as \
111
+ # left closed and right open interval [v_i, v_i+1), which makes\
112
+ # the value at right end always 0. It is undesired,so that we \
113
+ # need to handle the last basis specially
114
+ b0 = torch.where((u >= knots[i]) & (u <= knots[i + 1]), 1, 0)
115
+ else:
116
+ b0 = torch.where((u >= knots[i]) & (u < knots[i + 1]), 1, 0)
117
+ return torch.as_tensor(b0, dtype=self.dtype, device=self.device)
118
+ else:
119
+ denom1 = knots[i + k] - knots[i]
120
+ term1 = 0.0 if denom1 == 0 else (u - knots[i]) / denom1 * \
121
+ self._basis_function(i, k - 1,
122
+ knots, u,
123
+ **kwargs)
124
+ denom2 = knots[i + k + 1] - knots[i + 1]
125
+ term2 = 0.0 if denom2 == 0 else (knots[i + k + 1] - u) / denom2 * \
126
+ self._basis_function(i + 1, k - 1,
127
+ knots, u,
128
+ **kwargs)
129
+ return term1 + term2
130
+
131
+ def vel_basis(self, times: torch.Tensor) -> torch.Tensor:
132
+ """
133
+ Directly get the basis of velocity B-spline
134
+ :param times:
135
+ :return:
136
+ """
137
+
138
+ # phase = self.phase_generator.phase(times)
139
+ phase = self.time2phase(times)
140
+
141
+ # for clamped uni B-spline
142
+ vel_nots_vec = self.knots_vec[1:-1]
143
+ basis = \
144
+ [self._basis_function(i, self.degree_p - 1, vel_nots_vec, phase,
145
+ num_ctrlp=self.num_ctrlp - 1)
146
+ for i in range(self.num_ctrlp - 1)]
147
+ basis = torch.stack(basis, dim=-1)
148
+ if self.goal_basis:
149
+ gb = torch.ones_like(phase, dtype=self.dtype, device=self.device)[
150
+ ..., None]
151
+ basis = torch.cat([basis, gb], dim=-1)
152
+ return basis
153
+
154
+ def acc_basis(self, times: torch.Tensor) -> torch.Tensor:
155
+ """
156
+ Directly get the basis of acceleration B-spline
157
+ :param times:
158
+ :return:
159
+ """
160
+
161
+ # phase = self.phase_generator.phase(times)
162
+ phase = self.time2phase(times)
163
+
164
+ acc_knots_vec = self.knots_vec[2: -2]
165
+
166
+ basis = [
167
+ self._basis_function(i, self.degree_p - 2, acc_knots_vec, phase,
168
+ num_ctrlp=self.num_ctrlp - 2)
169
+ for i in range(self.num_ctrlp - 2)]
170
+ basis = torch.stack(basis, dim=-1)
171
+
172
+ return basis
173
+
174
+ def velocity_control_points(self, ctrl_pts: torch.Tensor):
175
+ """
176
+ given the position control points (parameter), return the velocity control
177
+ points for vel B-spline as linear combination of position control points.
178
+
179
+ :param ctrl_pts: vector of position control points
180
+ :return: velocity control points
181
+ """
182
+ # diff shape: [*add_dim, num_dof, num_ctrlp-1]
183
+ diff = ctrl_pts[..., 1:] - ctrl_pts[..., :-1]
184
+ # shape: [num_basis-1]
185
+ delta = self.knots_vec[
186
+ 1 + self.degree_p: self.num_ctrlp + self.degree_p] - \
187
+ self.knots_vec[1: self.num_ctrlp]
188
+ diff = diff * (1 / delta)
189
+ return diff * self.degree_p
190
+
191
+ def acceleration_control_points(self, ctrl_pts: torch.Tensor):
192
+ """
193
+ given the position control points (parameter), return the acceleration
194
+ control points for acc B-spline as linear combination of position
195
+ control points.
196
+
197
+ :param ctrl_pts: vector of position control points
198
+ :return: velocity control points
199
+ """
200
+ # shape: [*add_dim, num_dof, num_ctrlp-1]
201
+ vel_ctrl_pts = self.velocity_control_points(ctrl_pts)
202
+ # shape: [*add_dim, num_dof, num_ctrlp-2]
203
+ diff = vel_ctrl_pts[..., 1:] - vel_ctrl_pts[..., :-1]
204
+ # shape: [num_ctrlp-2]
205
+ # delta = self.knots_vec[2+self.degree_p: self.num_ctrlp+self.degree_p-1]\
206
+ # - self.knots_vec[2: self.num_ctrlp-1]
207
+ delta = self.knots_vec[
208
+ 2 + self.degree_p: self.num_ctrlp + self.degree_p] \
209
+ - self.knots_vec[2: self.num_ctrlp]
210
+ diff = diff * (1 / delta)
211
+ return diff * (self.degree_p - 1)
212
+
213
+ def compute_init_params(self, init_pos, init_vel, **kwargs):
214
+ """
215
+ Given initial condition, compute corresponding the first control points
216
+ :param init_pos:
217
+ :param init_vel:
218
+ :param kwargs:
219
+ :return:
220
+ """
221
+
222
+ # Shape of init_pos:
223
+ # [*add_dim, num_dof]
224
+ #
225
+ # Shape of init_vel:
226
+ # [*add_dim, num_dof]
227
+ #
228
+ # return shape:
229
+ # [*add_dim, num_dof, init_cond_order]
230
+
231
+ if self.init_cond_order == 0:
232
+ return None
233
+
234
+ para_init_p = init_pos
235
+ para_init = para_init_p[..., None]
236
+
237
+ if self.init_cond_order == 2:
238
+ para_init_v = \
239
+ torch.einsum("...i,...->...i", init_vel,
240
+ self.tau) * \
241
+ (self.knots_vec[1 + self.degree_p] - self.knots_vec[1]) \
242
+ / self.degree_p + para_init_p
243
+ para_init = torch.cat([para_init, para_init_v[..., None]], dim=-1)
244
+
245
+ return para_init
246
+
247
+ def compute_end_params(self, end_pos, end_vel, **kwargs):
248
+ """
249
+ Given end condition, compute corresponding the last control points
250
+ :param end_pos:
251
+ :param end_vel:
252
+ :param kwargs:
253
+ :return:
254
+ """
255
+ # Shape of end_pos:
256
+ # [*add_dim, num_dof]
257
+ #
258
+ # Shape of end_vel:
259
+ # [*add_dim, num_dof]
260
+ #
261
+ # return shape:
262
+ # [*add_dim, num_dof, init_cond_order]
263
+
264
+ if self.end_cond_order == 0:
265
+ return None
266
+
267
+ para_end_p = end_pos
268
+ para_end = para_end_p[..., None]
269
+
270
+ if self.end_cond_order == 2:
271
+ para_end_v = para_end_p - \
272
+ torch.einsum("...i,...->...i", end_vel,
273
+ self.tau) * \
274
+ (self.knots_vec[self.num_ctrlp - 1 + self.degree_p] -
275
+ self.knots_vec[self.num_ctrlp - 1]) * self.degree_p
276
+ # para_end_v = para_end_p - (end_vel * self.phase_generator.tau) * \
277
+ # (self.knots_vec[self.num_ctrlp - 1 + self.degree_p] -
278
+ # self.knots_vec[self.num_ctrlp-1]) * self.degree_p
279
+ para_end = torch.cat([para_end_v[..., None], para_end], dim=-1)
280
+
281
+ return para_end
282
+
283
+ def basis_multi_dofs(self,
284
+ times: torch.Tensor,
285
+ num_dof: int) -> torch.Tensor:
286
+ """
287
+ Interface to generate value of single basis function at given time
288
+ points
289
+ Args:
290
+ times: times in Tensor
291
+ num_dof: num of Degree of freedoms
292
+ Returns:
293
+ basis_multi_dofs: Multiple DoFs basis functions in Tensor
294
+
295
+ """
296
+ # Shape of time
297
+ # [*add_dim, num_times]
298
+ #
299
+ # Shape of basis_multi_dofs
300
+ # [*add_dim, num_dof * num_times, num_dof * num_basis]
301
+
302
+ # Extract additional dimensions
303
+ add_dim = list(times.shape[:-1])
304
+
305
+ # Get single basis, shape: [*add_dim, num_times, num_ctrlp]
306
+ basis_single_dof = self.basis(times)
307
+ # num_times = basis_single_dof.shape[-2]
308
+ num_times = times.shape[-1]
309
+
310
+ # shape: [*add_dim, num_times, num_basis]
311
+ basis_single_dof_ = basis_single_dof[..., self.init_cond_order:
312
+ self.num_ctrlp - self.end_cond_order]
313
+ # Multiple Dofs, shape:
314
+ # [*add_dim, num_dof * num_times, num_dof * num_basis]
315
+ basis_multi_dofs = torch.zeros(*add_dim, num_dof * num_times,
316
+ num_dof * self.num_basis,
317
+ dtype=self.dtype,
318
+ device=self.device)
319
+
320
+ # Assemble
321
+ for i in range(num_dof):
322
+ row_indices = slice(i * num_times, (i + 1) * num_times)
323
+ col_indices = slice(i * self.num_basis, (i + 1) * self.num_basis)
324
+ basis_multi_dofs[..., row_indices, col_indices] = basis_single_dof_
325
+
326
+ # Return
327
+ return basis_multi_dofs
328
+
329
+ def show_basis(self, plot=False) -> Tuple[torch.Tensor, torch.Tensor]:
330
+ """
331
+ Compute basis function values for debug usage
332
+ The times are in the range of [delay - tau, delay + 2 * tau]
333
+
334
+ Returns: basis function values
335
+
336
+ """
337
+ times = torch.linspace(0, 1, steps=1000)
338
+ basis_values = self.basis(times)
339
+ if plot:
340
+ import matplotlib.pyplot as plt
341
+ plt.figure()
342
+ for i in range(basis_values.shape[-1]):
343
+ plt.plot(times, basis_values[:, i], label=f"basis_{i}")
344
+ plt.grid()
345
+ plt.legend()
346
+ plt.axvline(x=0, linestyle='--', color='k', alpha=0.3)
347
+ plt.axvline(x=1, linestyle='--', color='k', alpha=0.3)
348
+ plt.show()
349
+ return times, basis_values
beast.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .bspline_factory import SplineFactory
2
+ import torch
3
+ from addict import Dict
4
+ from .utils import continuous_to_discrete, discrete_to_continuous, normalize_tensor, denormalize_tensor, tensor_linspace
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import einops
8
+
9
+ from transformers.processing_utils import ProcessorMixin
10
+
11
+ from functools import wraps
12
+
13
+ def autocast_float32(fn):
14
+ @wraps(fn)
15
+ def wrapped(*args, **kwargs):
16
+ with torch.cuda.amp.autocast(dtype=torch.float32):
17
+ return fn(*args, **kwargs)
18
+ return wrapped
19
+
20
+ class BeastTokenizer(torch.nn.Module, ProcessorMixin):
21
+ """
22
+ B-spline based tokenizer for trajectory encoding/decoding.
23
+
24
+ Converts continuous trajectories to discrete tokens and vice versa using B-splines.
25
+ Supports continuous and discrete representations of trajectories.
26
+ Supports sperate handling for continous action and discrete state (e.g., binarized gripper state).
27
+ """
28
+
29
+ # Class constants
30
+ DEFAULT_DT = 0.01 # 100 Hz sampling rate
31
+
32
+ def __init__(self, num_dof=1, num_basis=10, seq_len=50, vocab_size=256,
33
+ degree_p=4, gripper_zero_order=False, gripper_dof=1, init_cond_order=0,
34
+ end_cond_order=0, enforce_init_pos=True, device="cuda"):
35
+ super().__init__()
36
+
37
+ # Store core parameters
38
+ self.device = device
39
+ self.seq_length = seq_len
40
+ self.vocab_size = vocab_size
41
+ self.num_basis = num_basis
42
+ self.enforce_init_pos = enforce_init_pos
43
+ self.init_cond_order = init_cond_order
44
+ self.end_cond_order = end_cond_order
45
+ self.dt = self.DEFAULT_DT
46
+ self.init_pos = None
47
+
48
+ # Calculate DOF distribution
49
+ self.gripper_dof = gripper_dof if gripper_zero_order else 0
50
+ self.joint_dof = num_dof - self.gripper_dof
51
+ self.num_dof = self.joint_dof + self.gripper_dof
52
+
53
+ # Initialize spline components
54
+ self.bsp = self._create_bsplines(self.joint_dof, degree_p)
55
+ self.gripper_bsp = self._create_bsplines(self.gripper_dof, 0) if gripper_zero_order else None
56
+
57
+ # Setup time grid and weight bounds
58
+ # Working with normalized time [0, 1]
59
+ self.times = tensor_linspace(0, 1.0, seq_len).to(device)
60
+ self._initialize_weight_bounds()
61
+
62
+ self.to(self.device)
63
+
64
+ def _create_bsplines(self, num_dof, degree_p):
65
+ """Create motion primitive for joint trajectories."""
66
+ config = Dict({
67
+ 'mp_type': 'uni_bspline',
68
+ 'device': self.device,
69
+ 'num_dof': num_dof,
70
+ 'tau': 1.0,
71
+ 'mp_args': {
72
+ 'num_basis': self.num_basis,
73
+ 'degree_p': degree_p,
74
+ 'init_condition_order': self.init_cond_order,
75
+ 'end_condition_order': self.end_cond_order,
76
+ 'dt': self.dt
77
+ }
78
+ })
79
+ return SplineFactory.init_splines(**config)
80
+
81
+ def _initialize_weight_bounds(self):
82
+ """Initialize weight bounds for normalization."""
83
+ total_params = self.num_dof * self.num_basis
84
+ self.register_buffer("w_min", -1.0 * torch.ones(total_params))
85
+ self.register_buffer("w_max", 1.0 * torch.ones(total_params))
86
+
87
+ def _get_repeated_times(self, batch_size):
88
+ """Get time tensor repeated for batch processing."""
89
+ return einops.repeat(self.times, 't -> b t', b=batch_size)
90
+
91
+ @autocast_float32
92
+ def _learn_trajectory_params(self, times, trajs):
93
+ """Learn B-spline parameters from trajectories."""
94
+ # Learn joint parameters
95
+ joint_params = self.bsp.learn_mp_params_from_trajs(times, trajs[..., :self.joint_dof])
96
+
97
+ # Learn gripper parameters if applicable
98
+ if self.gripper_bsp is not None:
99
+ gripper_trajs = trajs[..., -self.gripper_dof:]
100
+ gripper_params = self.gripper_bsp.learn_mp_params_from_trajs(times, gripper_trajs)
101
+ joint_params['params'] = torch.cat([joint_params['params'], gripper_params['params']], dim=-1)
102
+
103
+ return joint_params
104
+
105
+ @autocast_float32
106
+ def _reconstruct_trajectory(self, params, times):
107
+ """Reconstruct trajectory from B-spline parameters."""
108
+ # Reconstruct joint trajectory
109
+ joint_params = params[..., :self.joint_dof * self.num_basis]
110
+ self.bsp.update_inputs(times=times, params=joint_params)
111
+ position = self.bsp.get_traj_pos()
112
+
113
+ # Reconstruct gripper trajectory if applicable
114
+ if self.gripper_bsp is not None:
115
+ gripper_params = params[..., -self.gripper_dof * self.num_basis:]
116
+ self.gripper_bsp.update_inputs(times=times, params=gripper_params)
117
+ gripper_pos = self.gripper_bsp.get_traj_pos()
118
+ position = torch.cat([position, gripper_pos], dim=-1)
119
+
120
+ return position
121
+
122
+ def _apply_initial_position_constraint(self, params, init_pos):
123
+ """Apply initial position constraint to parameters."""
124
+ if not self.init_pos or init_pos is None:
125
+ return params
126
+
127
+ # Reshape to access individual basis functions
128
+ reshaped_params = einops.rearrange(params, "b (d t) -> b t d", t=self.num_basis, d=self.num_dof)
129
+
130
+ # Set initial position for joint DOFs
131
+ reshaped_params[:, 0, :self.joint_dof] = init_pos[:, :self.joint_dof]
132
+
133
+ return einops.rearrange(reshaped_params, "b t d -> b (d t)")
134
+
135
+ @autocast_float32
136
+ def compute_weights(self, demos):
137
+ """Compute B-spline weights from demonstration trajectories."""
138
+ times = self._get_repeated_times(demos.shape[0])
139
+ weights = self.bsp.learn_mp_params_from_trajs(times, demos)['params']
140
+ return weights
141
+
142
+ def update_weights_bounds_per_batch(self, weights):
143
+ """Update weight bounds based on batch statistics."""
144
+ weights = weights.reshape(-1, self.num_dof * self.num_basis)
145
+ batch_min = weights.min(dim=0)[0]
146
+ batch_max = weights.max(dim=0)[0]
147
+
148
+ # Update bounds with small tolerance
149
+ tolerance = 1e-4
150
+ smaller_mask = batch_min < (self.w_min - tolerance)
151
+ larger_mask = batch_max > (self.w_max + tolerance)
152
+
153
+ if torch.any(smaller_mask):
154
+ self.w_min[smaller_mask] = batch_min[smaller_mask]
155
+ if torch.any(larger_mask):
156
+ self.w_max[larger_mask] = batch_max[larger_mask]
157
+
158
+ def update_times(self, times):
159
+ """Update time grid."""
160
+ self.times = times
161
+
162
+ @torch.no_grad()
163
+ @autocast_float32
164
+ def encode_discrete(self, trajs, update_bounds=False, init_p=None):
165
+ """Encode trajectories to discrete tokens."""
166
+ times = self._get_repeated_times(trajs.shape[0])
167
+ params_dict = self._learn_trajectory_params(times, trajs)
168
+
169
+ if update_bounds:
170
+ self.update_weights_bounds_per_batch(params_dict['params'])
171
+
172
+ # Clamp parameters to bounds
173
+ params = torch.clamp(params_dict['params'], min=self.w_min, max=self.w_max)
174
+
175
+ # Convert to discrete tokens
176
+ tokens = continuous_to_discrete(params, min_val=self.w_min, max_val=self.w_max, num_bins=self.vocab_size)
177
+ tokens = einops.rearrange(tokens, 'b (d t) -> b (t d)', t=self.num_basis, d=self.num_dof)
178
+
179
+ return tokens
180
+
181
+ @torch.no_grad()
182
+ @autocast_float32
183
+ def decode_discrete(self, tokens, times=None, init_pos=None):
184
+ """Decode discrete tokens to trajectories."""
185
+ # Reshape tokens and convert to continuous parameters
186
+ tokens = einops.rearrange(tokens, 'b (t d) -> b (d t)', t=self.num_basis, d=self.num_dof)
187
+ params = discrete_to_continuous(tokens, min_val=self.w_min, max_val=self.w_max, num_bins=self.vocab_size)
188
+
189
+ if times is None:
190
+ times = self._get_repeated_times(params.shape[0])
191
+
192
+ # Apply initial position constraint if specified
193
+ params = self._apply_initial_position_constraint(params, init_pos)
194
+
195
+ return self._reconstruct_trajectory(params, times)
196
+
197
+ @torch.no_grad()
198
+ @autocast_float32
199
+ def encode_continuous(self, trajs, update_bounds=False):
200
+ """Encode trajectories to continuous tokens (normalized parameters)."""
201
+ times = self._get_repeated_times(trajs.shape[0])
202
+ params_dict = self._learn_trajectory_params(times, trajs)
203
+
204
+ if update_bounds:
205
+ self.update_weights_bounds_per_batch(params_dict['params'])
206
+
207
+ # Normalize parameters
208
+ tokens = normalize_tensor(params_dict['params'], w_min=self.w_min, w_max=self.w_max)
209
+
210
+ return tokens
211
+
212
+ @torch.no_grad()
213
+ @autocast_float32
214
+ def decode_continuous(self, params, times=None, init_pos=None):
215
+ """Decode continuous tokens (normalized parameters) to trajectories."""
216
+ # Denormalize parameters
217
+ params = denormalize_tensor(params, w_min=self.w_min, w_max=self.w_max)
218
+
219
+ if times is None:
220
+ times = self._get_repeated_times(params.shape[0])
221
+
222
+ # Apply initial position constraint if specified
223
+ params = self._apply_initial_position_constraint(params, init_pos)
224
+
225
+ return self._reconstruct_trajectory(params, times)
226
+
227
+ @autocast_float32
228
+ def compute_reconstruction_error(self, raw_traj):
229
+ """Compute reconstruction error for trajectory."""
230
+ if len(raw_traj.shape) == 2:
231
+ raw_traj = raw_traj.unsqueeze(-1)
232
+
233
+ tokens, _ = self.encode_discrete(raw_traj)
234
+ reconstructed = self.decode_discrete(tokens)
235
+ error = torch.mean((raw_traj - reconstructed) ** 2)
236
+
237
+ return error
238
+
239
+ def _plot_trajectory_comparison(self, original, reconstructed, title_prefix=""):
240
+ """Helper method to plot trajectory comparison."""
241
+ original = original.detach().cpu().numpy()
242
+ reconstructed = reconstructed.detach().cpu().numpy()
243
+ x_vals = np.linspace(0, 1.0, original.shape[1])
244
+
245
+ batch_size, time_steps, dof = original.shape
246
+
247
+ for sample_idx in range(batch_size):
248
+ fig, axes = plt.subplots(dof, 1, figsize=(8, 2 * dof), sharex=True)
249
+ if dof == 1:
250
+ axes = [axes] # Handle single DOF case
251
+
252
+ for i in range(dof):
253
+ axes[i].plot(x_vals, reconstructed[sample_idx, :, i],
254
+ marker='o', label='Reconstructed', linestyle='-', color='b')
255
+ axes[i].plot(x_vals, original[sample_idx, :, i],
256
+ marker='*', label='Ground Truth', linestyle='--', color='r')
257
+ axes[i].set_ylabel(f"DOF {i + 1}")
258
+ axes[i].grid(True)
259
+ axes[i].legend(loc="best")
260
+
261
+ axes[-1].set_xlabel("Time (s)")
262
+ plt.suptitle(f"{title_prefix}Trajectory Comparison - Sample {sample_idx}")
263
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
264
+ plt.show()
265
+
266
+ def visualize_reconstruction_error_discrete(self, raw_traj):
267
+ """Visualize reconstruction error for discrete encoding."""
268
+ tokens = self.encode_discrete(raw_traj, update_bounds=True)
269
+ reconstructed = self.decode_discrete(tokens)
270
+ self._plot_trajectory_comparison(raw_traj, reconstructed, "Discrete ")
271
+
272
+ def visualize_reconstruction_error_continuous(self, raw_traj):
273
+ """Visualize reconstruction error for continuous encoding."""
274
+ raw_traj = raw_traj.to(torch.float32)
275
+ if len(raw_traj.shape) == 2:
276
+ raw_traj = raw_traj.unsqueeze(0)
277
+
278
+ continuous_tokens = self.encode_continuous(raw_traj, update_bounds=True)
279
+ reconstructed = self.decode_continuous(continuous_tokens)
280
+ self._plot_trajectory_comparison(raw_traj, reconstructed, "Continuous ")
bspline_factory.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from .basis_gn import UniBSplineBasis
4
+ from .uni_bspline import UniformBSpline
5
+
6
+
7
+ class SplineFactory:
8
+
9
+ @staticmethod
10
+ def init_splines(mp_type: str,
11
+ mp_args: dict,
12
+ num_dof: int = 1,
13
+ tau: float = 1,
14
+ dtype: torch.dtype = torch.float32,
15
+ device: torch.device = "cpu"):
16
+
17
+ if mp_type == "uni_bspline":
18
+ basis_gn = UniBSplineBasis(dtype=dtype, device=device, tau=tau,
19
+ **mp_args)
20
+ mp = UniformBSpline(basis_gn=basis_gn, num_dof=num_dof,
21
+ dtype=dtype, device=device, **mp_args)
22
+ else:
23
+ raise NotImplementedError
24
+
25
+ return mp
uni_bspline.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Union, Optional
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from .basis_gn import UniBSplineBasis
9
+
10
+ class UniformBSpline(torch.nn.Module):
11
+
12
+ def __init__(self,
13
+ basis_gn: UniBSplineBasis,
14
+ num_dof: int,
15
+ weights_scale: float = 1.,
16
+ dtype: torch.dtype = torch.float32,
17
+ device: torch.device = 'cpu',
18
+ **kwargs,
19
+ ):
20
+ super().__init__()
21
+
22
+ # self.dtype = dtype
23
+ # self.device = device
24
+ # batch dim
25
+ self.add_dim = list()
26
+
27
+ self.basis_gn = basis_gn
28
+ self.num_dof = num_dof
29
+
30
+ # Scaling of weights
31
+ weights_scale = \
32
+ torch.tensor(weights_scale, dtype=self.dtype, device=self.device)
33
+ assert weights_scale.ndim <= 1, \
34
+ "weights_scale should be float or 1-dim vector"
35
+ self.register_buffer("weights_scale", weights_scale, persistent=False)
36
+
37
+ # Value caches
38
+ # Compute values at these time points
39
+ self.times = None
40
+
41
+ # Learnable parameters
42
+ self.params = None
43
+
44
+ # Initial conditions
45
+ self.init_pos = None
46
+ self.init_vel = None
47
+
48
+ # Runtime computation results, shall be reset every time when
49
+ # inputs are reset
50
+ self.pos = None
51
+ self.vel = None
52
+
53
+
54
+ #parameters bound
55
+ # params_bound = kwargs.get("params_bound", None)
56
+ # if not params_bound:
57
+ # params_bound = torch.zeros([2, self.num_params],
58
+ # dtype=self.dtype,
59
+ # device=self.device)
60
+ # params_bound[0, :] = -torch.inf
61
+ # params_bound[1, :] = torch.inf
62
+ # else:
63
+ # params_bound = torch.as_tensor(self.params_bound,
64
+ # dtype=self.dtype,
65
+ # device=self.device)
66
+ # assert list(params_bound.shape) == [2, self.num_params]
67
+ # self.register_buffer("params_bound", params_bound, persistent=False)
68
+
69
+
70
+ self.end_pos = None
71
+ self.end_vel = None
72
+
73
+ self.params_init = None
74
+ self.params_end = None
75
+
76
+ @property
77
+ def device(self):
78
+ return self.basis_gn.device
79
+
80
+ @property
81
+ def dtype(self):
82
+ return self.basis_gn.dtype
83
+
84
+ @property
85
+ def tau(self):
86
+ return self.basis_gn.tau
87
+
88
+ @property
89
+ def num_basis(self):
90
+ return self.basis_gn.num_basis
91
+
92
+ @property
93
+ def num_params(self):
94
+ return self.basis_gn.num_basis * self.num_dof
95
+
96
+ def clear_computation_result(self):
97
+ """
98
+ Clear runtime computation result
99
+
100
+ Returns:
101
+ None
102
+ """
103
+
104
+ self.pos = None
105
+ self.vel = None
106
+ # also reset tau?
107
+
108
+ def set_add_dim(self, add_dim: Union[list, torch.Size]):
109
+ """
110
+ Set additional batch dimension
111
+ Args:
112
+ add_dim: additional batch dimension
113
+
114
+ Returns: None
115
+
116
+ """
117
+ self.add_dim = add_dim
118
+ self.clear_computation_result()
119
+
120
+ def set_times(self, times: Union[torch.Tensor, np.ndarray]):
121
+ """
122
+ Set time points
123
+ Args:
124
+ times: time points
125
+
126
+ Returns:
127
+ None
128
+ """
129
+
130
+ # Shape of times
131
+ # [*add_dim, num_times]
132
+
133
+ self.times = torch.as_tensor(times, dtype=self.dtype,
134
+ device=self.device)
135
+ tau = times.reshape(-1)[-1]
136
+ self.basis_gn.tau.copy_(tau)
137
+ self.clear_computation_result()
138
+
139
+ def set_duration(self, duration: Optional[float], dt: float,):
140
+ """
141
+
142
+ Args:
143
+ duration: desired duration of trajectory
144
+ dt: control frequency
145
+ Returns:
146
+ None
147
+ """
148
+
149
+ # Shape of times
150
+ # [*add_dim, num_times]
151
+ dt = torch.as_tensor(dt, dtype=self.dtype, device=self.device)
152
+ times = torch.linspace(0, duration, round(duration / dt) + 1,
153
+ dtype=self.dtype, device=self.device)
154
+ times = add_expand_dim(times, list(range(len(self.add_dim))),
155
+ self.add_dim)
156
+ self.set_times(times)
157
+
158
+ def set_params(self,
159
+ params: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
160
+ """
161
+ Set MP params
162
+ Args:
163
+ params: parameters
164
+
165
+ Returns: unused parameters
166
+
167
+ """
168
+ # Shape of params
169
+ # [*add_dim, num_params]
170
+
171
+ params = torch.as_tensor(params, dtype=self.dtype, device=self.device)
172
+
173
+ # Check number of params
174
+ assert params.shape[-1] == self.num_params
175
+
176
+ # Set additional batch size
177
+ self.set_add_dim(list(params.shape[:-1]))
178
+
179
+ self.params = params[..., :self.num_params]
180
+ self.clear_computation_result()
181
+ return params[..., self.num_params:]
182
+
183
+ def update_inputs(self, times=None, params=None,
184
+ init_pos=None, init_vel=None, **kwargs):
185
+
186
+ if params is not None:
187
+ self.set_params(params)
188
+ if times is not None:
189
+ self.set_times(times)
190
+ if init_pos is not None:
191
+ self.set_initial_conditions(init_pos, init_vel, **kwargs)
192
+
193
+ end_pos = kwargs.get('end_pos', None)
194
+ end_vel = kwargs.get('end_vel', None)
195
+ if any([cond is not None for cond in [end_pos, end_vel]]):
196
+ self.set_end_condtions(end_pos, end_vel)
197
+
198
+ def set_initial_conditions(self,
199
+ init_pos: Union[torch.Tensor, np.ndarray],
200
+ init_vel: Union[torch.Tensor, np.ndarray],
201
+ **kwargs):
202
+
203
+ self.init_pos = torch.as_tensor(init_pos, dtype=self.dtype,
204
+ device=self.device)
205
+ self.init_vel = torch.as_tensor(init_vel, dtype=self.dtype,
206
+ device=self.device) if init_vel is not None else None
207
+ self.clear_computation_result()
208
+
209
+ self.params_init = self.basis_gn.compute_init_params(self.init_pos, self.init_vel)
210
+ if self.params_init is not None:
211
+ self.params_init /= self.weights_scale
212
+
213
+ def set_end_condtions(self, end_pos: Union[torch.Tensor, np.ndarray],
214
+ end_vel: Union[torch.Tensor, np.ndarray], **kwargs):
215
+ self.end_pos = \
216
+ torch.as_tensor(end_pos, device=self.device, dtype=self.dtype) \
217
+ if end_pos is not None else None
218
+ self.end_vel = \
219
+ torch.as_tensor(end_vel, device=self.device, dtype=self.dtype) \
220
+ if end_vel is not None else None
221
+
222
+ self.params_end = self.basis_gn.compute_end_params(self.end_pos, self.end_vel)
223
+ if self.params_end is not None:
224
+ self.params_end /= self.weights_scale
225
+
226
+ def get_traj_pos(self, times=None, params=None,
227
+ init_pos=None, init_vel=None, flat_shape=False, **kwargs):
228
+
229
+ self.update_inputs(times, params, init_pos, init_vel, **kwargs)
230
+
231
+ if self.pos is not None:
232
+ pos = self.pos
233
+ else:
234
+ assert self.params is not None
235
+
236
+ # Reshape params
237
+ # [*add_dim, num_dof * num_basis] -> [*add_dim, num_dof, num_basis]
238
+ params = self.params.reshape(*self.add_dim, self.num_dof, -1)
239
+ # extend params with possible init and end conditions
240
+ # shape: [*add_dim, num_dof, num_ctrlp]
241
+ if self.params_init is not None:
242
+ params = torch.cat((self.params_init, params), dim=-1)
243
+ if self.params_end is not None:
244
+ params = torch.cat((params, self.params_end), dim=-1)
245
+
246
+ # Get basis
247
+ # Shape: [*add_dim, num_times, num_ctrlp]
248
+ basis_single_dof = \
249
+ self.basis_gn.basis(self.times) * self.weights_scale
250
+
251
+ # Einsum shape: [*add_dim, num_times, num_ctrlp],
252
+ # [*add_dim, num_dof, num_ctrlp]
253
+ # -> [*add_dim, num_times, num_dof]
254
+ pos = torch.einsum('...ik,...jk->...ij', basis_single_dof, params)
255
+
256
+ self.pos = pos
257
+
258
+ if flat_shape:
259
+ # Switch axes to [*add_dim, num_dof, num_times]
260
+ pos = torch.einsum('...ji->...ij', pos)
261
+
262
+ # Reshape to [*add_dim, num_dof * num_times]
263
+ pos = pos.reshape(*self.add_dim, -1)
264
+
265
+ return pos
266
+
267
+ def get_traj_vel(self, times=None, params=None,
268
+ init_pos=None, init_vel=None, flat_shape=False, **kwargs):
269
+
270
+ self.update_inputs(times, params, init_pos, init_vel,
271
+ **kwargs)
272
+
273
+ if self.vel is not None:
274
+ vel = self.vel
275
+ else:
276
+ assert self.params is not None
277
+
278
+ # Reshape params
279
+ # [*add_dim, num_dof * num_basis] -> [*add_dim, num_dof, num_basis]
280
+ params = self.params.reshape(*self.add_dim, self.num_dof, -1)
281
+ # extend params with possible init and end conditions
282
+ # shape: [*add_dim, num_dof, num_ctrlp]
283
+ if self.params_init is not None:
284
+ params = torch.cat((self.params_init, params), dim=-1)
285
+ if self.params_end is not None:
286
+ params = torch.cat((params, self.params_end), dim=-1)
287
+
288
+ # velocity control points shape: [*add_dim, num_dof, num_ctrlp-1]
289
+ vel_ctrlp = self.basis_gn.velocity_control_points(params)
290
+ vel_ctrlp = torch.einsum("...ij,...->...ij", vel_ctrlp,
291
+ 1 / self.tau)
292
+
293
+ # vel_basis shape: [*add_dim, num_times, num_ctrlp-1]
294
+ vel_basis = self.basis_gn.vel_basis(self.times) * self.weights_scale
295
+
296
+ # Einsum shape: [*add_dim, num_times, num_ctrlp-1],
297
+ # [*add_dim, num_dof, num_ctrlp-1]
298
+ # -> [*add_dim, num_times, num_dof]
299
+ vel = torch.einsum('...ik,...jk->...ij', vel_basis, vel_ctrlp)
300
+
301
+ self.vel = vel
302
+
303
+ if flat_shape:
304
+ # Switch axes to [*add_dim, num_dof, num_times]
305
+ vel = torch.einsum('...ji->...ij', vel)
306
+
307
+ # Reshape to [*add_dim, num_dof * num_times]
308
+ vel = vel.reshape(*self.add_dim, -1)
309
+
310
+ return vel
311
+
312
+ def learn_mp_params_from_trajs(self, times: torch.Tensor,
313
+ trajs: torch.Tensor, reg=1e-5, **kwargs):
314
+
315
+ # only works for learn_tau=False, learn_delay=False. And delay=0 (or you
316
+ # need to give the initial condition by yourself)
317
+
318
+ # Shape of times
319
+ # [*add_dim, num_times]
320
+ #
321
+ # Shape of trajs:
322
+ # [*add_dim, num_times, num_dof]
323
+ #
324
+ # Shape of params:
325
+ # [*add_dim, num_dof * num_basis]
326
+
327
+ assert trajs.shape[:-1] == times.shape
328
+ assert trajs.shape[-1] == self.num_dof
329
+
330
+ times = torch.as_tensor(times, dtype=self.dtype, device=self.device)
331
+ trajs = torch.as_tensor(trajs, dtype=self.dtype, device=self.device)
332
+
333
+ # Setup stuff
334
+ self.set_add_dim(list(trajs.shape[:-2]))
335
+ self.set_times(times)
336
+ dummy_params = torch.zeros(*self.add_dim, self.num_dof, self.num_basis,
337
+ device=self.device, dtype=self.dtype)
338
+
339
+ # Get initial conditions
340
+ if self.basis_gn.init_cond_order != 0:
341
+ if any([key in kwargs.keys()
342
+ for key in [ "init_pos", "init_vel"]]):
343
+ logging.warning("uses the given initial conditions")
344
+ init_pos = kwargs.get("init_pos")
345
+ init_vel = kwargs.get("init_vel")
346
+ else:
347
+ init_pos = trajs[..., 0, :]
348
+ dt = (times[..., 1] - times[..., 0])
349
+ init_vel = torch.einsum("...i,...->...i",
350
+ torch.diff(trajs, dim=-2)[..., 0, :],
351
+ 1/dt)
352
+ self.set_initial_conditions(init_pos, init_vel)
353
+ if self.params_init is not None:
354
+ dummy_params = torch.cat([self.params_init, dummy_params],
355
+ dim=-1)
356
+
357
+ if self.basis_gn.end_cond_order != 0:
358
+ if any([key in kwargs.keys()
359
+ for key in ["end_pos", "end_vel"]]):
360
+ logging.warning("uses the given end conditions")
361
+ end_pos = kwargs.get("end_pos")
362
+ end_vel = kwargs.get("end_vel")
363
+ else:
364
+ end_pos = trajs[..., -1, :]
365
+ dt = (times[..., 1] - times[..., 0])
366
+ end_vel = torch.einsum("...i,...->...i",
367
+ torch.diff(trajs, dim=-2)[..., -1, :],
368
+ 1/dt)
369
+ self.set_end_condtions(end_pos, end_vel)
370
+ if self.params_end is not None:
371
+ dummy_params = torch.cat([dummy_params, self.params_end],
372
+ dim=-1)
373
+
374
+ basis_single_dof = self.basis_gn.basis(times) * self.weights_scale
375
+ # shape: [*add_dim, num_time, num_ctrlp]
376
+ # [*add_dim, num_dof, num_ctrlp]
377
+ # [*add_dim, num_times, num_dof]
378
+ pos_det = torch.einsum('...ik,...jk->...ij', basis_single_dof, dummy_params)
379
+ # swtich axes to [*add_dim, num_dof, num_times]
380
+ pos_det = torch.einsum('...ij->...ji', pos_det)
381
+ pos_det = pos_det.reshape(*self.add_dim, -1)
382
+
383
+ basis_multi_dofs = self.basis_gn.basis_multi_dofs(self.times, self.num_dof) * self.weights_scale
384
+ # Solve this: Aw = B -> w = A^{-1} B
385
+ # Einsum_shape: [*add_dim, num_dof * num_times, num_dof * num_basis]
386
+ # [*add_dim, num_dof * num_times, num_dof * num_basis]
387
+ # -> [*add_dim, num_dof * num_basis, num_dof * num_basis]
388
+ A = torch.einsum('...ki,...kj->...ij', basis_multi_dofs,
389
+ basis_multi_dofs)
390
+ A += torch.eye(self.num_params,
391
+ dtype=self.dtype,
392
+ device=self.device) * reg
393
+
394
+ # Swap axis and reshape: [*add_dim, num_times, num_dof]
395
+ # -> [*add_dim, num_dof, num_times]
396
+ trajs = torch.einsum("...ij->...ji", trajs)
397
+ # Reshape [*add_dim, num_dof, num_times]
398
+ # -> [*add_dim, num_dof * num_times]
399
+ trajs = trajs.reshape([*self.add_dim, -1])
400
+
401
+ # Position minus initial condition terms,
402
+ pos_w = trajs - pos_det
403
+
404
+ # Einsum_shape: [*add_dim, num_dof * num_times, num_dof * num_basis]
405
+ # [*add_dim, num_dof * num_times]
406
+ # -> [*add_dim, num_dof * num_basis]
407
+ B = torch.einsum('...ki,...k->...i', basis_multi_dofs, pos_w)
408
+
409
+ # Shape of weights: [*add_dim, num_dof * num_basis]
410
+ params = torch.linalg.solve(A, B)
411
+
412
+ self.set_params(params)
413
+
414
+ return {"params": params,
415
+ "init_pos": self.init_pos,
416
+ "init_vel": self.init_vel,
417
+ "end_pos": self.end_pos,
418
+ "end_vel": self.end_vel,
419
+ }
420
+
421
+
422
+ def add_expand_dim(data: Union[torch.Tensor, np.ndarray],
423
+ add_dim_indices: [int],
424
+ add_dim_sizes: [int]) -> Union[torch.Tensor, np.ndarray]:
425
+ """
426
+ Add additional dimensions to tensor and expand accordingly
427
+ Args:
428
+ data: tensor to be operated. Torch.Tensor or numpy.ndarray
429
+ add_dim_indices: the indices of added dimensions in the result tensor
430
+ add_dim_sizes: the expanding size of the additional dimensions
431
+
432
+ Returns:
433
+ result: result tensor after adding and expanding
434
+ """
435
+ num_data_dim = data.ndim
436
+ num_dim_to_add = len(add_dim_indices)
437
+
438
+ add_dim_reverse_indices = [num_data_dim + num_dim_to_add + idx for idx in
439
+ add_dim_indices]
440
+
441
+ str_add_dim = ""
442
+ str_expand = ""
443
+ add_dim_index = 0
444
+ for dim in range(num_data_dim + num_dim_to_add):
445
+ if dim in add_dim_indices or dim in add_dim_reverse_indices:
446
+ str_add_dim += "None, "
447
+ str_expand += str(add_dim_sizes[add_dim_index]) + ", "
448
+ add_dim_index += 1
449
+ else:
450
+ str_add_dim += ":, "
451
+ if type(data) == torch.Tensor:
452
+ str_expand += "-1, "
453
+ elif type(data) == np.ndarray:
454
+ str_expand += "1, "
455
+ else:
456
+ raise NotImplementedError
457
+
458
+ str_add_dime_eval = "data[" + str_add_dim + "]"
459
+ if type(data) == torch.Tensor:
460
+ return eval("eval(str_add_dime_eval).expand(" + str_expand + ")")
461
+ else:
462
+ return eval("np.tile(eval(str_add_dime_eval),[" + str_expand + "])")
utils.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+ from typing import Union
4
+
5
+ def continuous_to_discrete(tensor, min_val=None, max_val=None, num_bins=256):
6
+ """
7
+ Convert a continuous PyTorch tensor to discrete tokens in the range [0, 255].
8
+
9
+ Args:
10
+ tensor (torch.Tensor): Input tensor with continuous values.
11
+ min_val (float, optional): Minimum value for normalization. If None, use tensor.min().
12
+ max_val (float, optional): Maximum value for normalization. If None, use tensor.max().
13
+
14
+ Returns:
15
+ torch.Tensor: Discretized tensor with values in the range [0, 255].
16
+ """
17
+
18
+ if min_val is None:
19
+ min_val = tensor.min()
20
+ if max_val is None:
21
+ max_val = tensor.max()
22
+
23
+ # Normalize the tensor to [0, 1]
24
+ assert torch.all(tensor >= min_val - 1e-3), "Input tensor has values below min_val"
25
+ assert torch.all(tensor <= max_val + 1e-3), "Input tensor has values above max_val"
26
+ normalized_tensor = (tensor - min_val) / (max_val - min_val)
27
+ normalized_tensor = torch.clamp(normalized_tensor, 0, 1)
28
+
29
+ # Ensure no out-of-bound values
30
+ # Scale to [0, 255] and quantize to integers
31
+ discrete_tensor = torch.round(normalized_tensor * (num_bins-1)).to(torch.long)
32
+ return discrete_tensor
33
+
34
+
35
+ def discrete_to_continuous(discrete_tensor, min_val=0, max_val=1, num_bins=256):
36
+ """
37
+ Convert a discrete PyTorch tensor with values in the range [0, 255]
38
+ back to continuous values in the range [min_val, max_val].
39
+
40
+ Args:
41
+ discrete_tensor (torch.Tensor): Input tensor with discrete values (0 to 255).
42
+ min_val (float): Minimum value of the original continuous range.
43
+ max_val (float): Maximum value of the original continuous range.
44
+
45
+ Returns:
46
+ torch.Tensor: Continuous tensor with values in the range [min_val, max_val].
47
+ """
48
+ # Map discrete tokens to [0, 1]
49
+ # Normalize the tensor to [0, 1]
50
+ normalized_tensor = discrete_tensor.float() / (num_bins-1)
51
+
52
+ # Map normalized values to [min_val, max_val]
53
+ continuous_tensor = normalized_tensor * (max_val - min_val) + min_val
54
+
55
+ # Ensure no out-of-bound values
56
+ continuous_tensor = torch.clamp(continuous_tensor, min_val, max_val)
57
+ return continuous_tensor
58
+
59
+
60
+ def normalize_tensor(tensor, w_min, w_max, norm_min=-1.0, norm_max=1.0):
61
+ """
62
+ Normalize a tensor from its original range [w_min, w_max] to a new range [norm_min, norm_max].
63
+
64
+ Args:
65
+ tensor (torch.Tensor): Input tensor to be normalized
66
+ w_min (float): Minimum value bound of the original tensor
67
+ w_max (float): Maximum value bound of the original tensor
68
+ norm_min (float, optional): Minimum value of the normalized range. Defaults to 0.0.
69
+ norm_max (float, optional): Maximum value of the normalized range. Defaults to 1.0.
70
+
71
+ Returns:
72
+ torch.Tensor: Normalized tensor with values in range [norm_min, norm_max]
73
+ """
74
+
75
+ # Clip the input tensor to be within [w_min, w_max]
76
+ clipped_tensor = torch.clamp(tensor, w_min, w_max)
77
+
78
+ # Normalize to [0, 1] range first
79
+ normalized = (clipped_tensor - w_min) / (w_max - w_min)
80
+
81
+ # Scale to the desired [norm_min, norm_max] range
82
+ normalized = normalized * (norm_max - norm_min) + norm_min
83
+
84
+ return normalized
85
+
86
+ def denormalize_tensor(normalized_tensor, w_min, w_max, norm_min=-1.0, norm_max=1.0):
87
+ """
88
+ Denormalize a tensor from the normalized range [norm_min, norm_max] back to the original range [w_min, w_max].
89
+
90
+ Args:
91
+ normalized_tensor (torch.Tensor): Normalized input tensor
92
+ w_min (float): Minimum value bound of the original range
93
+ w_max (float): Maximum value bound of the original range
94
+ norm_min (float, optional): Minimum value of the normalized range. Defaults to 0.0.
95
+ norm_max (float, optional): Maximum value of the normalized range. Defaults to 1.0.
96
+
97
+ Returns:
98
+ torch.Tensor: Denormalized tensor with values in range [w_min, w_max]
99
+ """
100
+
101
+ # Clip the normalized tensor to be within [norm_min, norm_max]
102
+ clipped_tensor = torch.clamp(normalized_tensor, norm_min, norm_max)
103
+
104
+ # Scale from [norm_min, norm_max] to [0, 1] first
105
+ denormalized = (clipped_tensor - norm_min) / (norm_max - norm_min)
106
+
107
+ # Scale to the original [w_min, w_max] range
108
+ denormalized = denormalized * (w_max - w_min) + w_min
109
+
110
+ return denormalized
111
+
112
+
113
+ def tensor_linspace(start: Union[float, int, torch.Tensor],
114
+ end: Union[float, int, torch.Tensor],
115
+ steps: int) -> torch.Tensor:
116
+ """
117
+ Vectorized version of torch.linspace.
118
+ Modified from:
119
+ https://github.com/zhaobozb/layout2im/blob/master/models/bilinear.py#L246
120
+
121
+ Args:
122
+ start: start value, scalar or tensor
123
+ end: end value, scalar or tensor
124
+ steps: num of steps
125
+
126
+ Returns:
127
+ linspace tensor
128
+ """
129
+ # Shape of start:
130
+ # [*add_dim, dim_data] or a scalar
131
+ #
132
+ # Shape of end:
133
+ # [*add_dim, dim_data] or a scalar
134
+ #
135
+ # Shape of out:
136
+ # [*add_dim, steps, dim_data]
137
+
138
+ # - out: Tensor of shape start.size() + (steps,), such that
139
+ # out.select(-1, 0) == start, out.select(-1, -1) == end,
140
+ # and the other elements of out linearly interpolate between
141
+ # start and end.
142
+
143
+ if isinstance(start, torch.Tensor) and not isinstance(end, torch.Tensor):
144
+ end += torch.zeros_like(start)
145
+ elif not isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
146
+ start += torch.zeros_like(end)
147
+ elif isinstance(start, torch.Tensor) and isinstance(end, torch.Tensor):
148
+ assert start.size() == end.size()
149
+ else:
150
+ return torch.linspace(start, end, steps)
151
+
152
+ view_size = start.size() + (1,)
153
+ w_size = (1,) * start.dim() + (steps,)
154
+ out_size = start.size() + (steps,)
155
+
156
+ start_w = torch.linspace(1, 0, steps=steps).to(start)
157
+ start_w = start_w.view(w_size).expand(out_size)
158
+ end_w = torch.linspace(0, 1, steps=steps).to(start)
159
+ end_w = end_w.view(w_size).expand(out_size)
160
+
161
+ start = start.contiguous().view(view_size).expand(out_size)
162
+ end = end.contiguous().view(view_size).expand(out_size)
163
+
164
+ out = start_w * start + end_w * end
165
+ out = torch.einsum('...ji->...ij', out)
166
+ return out