Upload folder using huggingface_hub
Browse files- online_bspline_tokenizer.py +21 -6
online_bspline_tokenizer.py
CHANGED
|
@@ -241,7 +241,7 @@ class BestBSpline:
|
|
| 241 |
|
| 242 |
def fit(self, trajectory: torch.Tensor, cpu_cores: int = min(mp.cpu_count(), 16),
|
| 243 |
tol_ratio: float = 0.03, absolute_tol: Optional[float] = 0.01,
|
| 244 |
-
time_limit: int = 10) -> Tuple[List[int], List[List[float]]]:
|
| 245 |
"""
|
| 246 |
Note: libero上设置为 absolute_tol, 0.01 为tokenization 误差。
|
| 247 |
Fit B-spline to trajectory using least squares.
|
|
@@ -252,7 +252,7 @@ class BestBSpline:
|
|
| 252 |
tol_ratio: Tolerance ratio for relative fitting error (eps = d_range × tol_ratio)
|
| 253 |
absolute_tol: Absolute tolerance for fitting error. If set, overrides tol_ratio
|
| 254 |
time_limit: Time limit for MILP solver in seconds
|
| 255 |
-
|
| 256 |
Returns:
|
| 257 |
full_knots, control_points
|
| 258 |
"""
|
|
@@ -396,22 +396,37 @@ class BestBSpline:
|
|
| 396 |
if self.use_gurobi:
|
| 397 |
solver_milp.close()
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
return full_knots_joint.tolist(), control_points
|
| 400 |
|
| 401 |
|
| 402 |
-
def batch_compress(self, batch_trajectory: torch.Tensor):
|
| 403 |
batch = batch_trajectory.shape[0]
|
| 404 |
cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
|
| 405 |
|
| 406 |
if self.is_multi_process and batch > 1:
|
| 407 |
with mp.Pool(processes=batch) as pool:
|
| 408 |
-
tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
|
| 409 |
results = pool.starmap(self.fit, tasks)
|
| 410 |
else:
|
| 411 |
# Single-threaded processing
|
| 412 |
results = []
|
| 413 |
for i in batch_trajectory:
|
| 414 |
-
res = self.fit(i, cpu_cores=cpu_cores_per_fit)
|
| 415 |
results.append(res)
|
| 416 |
|
| 417 |
batch_knots = [res[0] for res in results]
|
|
@@ -574,7 +589,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 574 |
"""
|
| 575 |
trajs = trajs.cpu() # 确保在 CPU 上处理
|
| 576 |
# Stage 1: adaptive bspline fit
|
| 577 |
-
knots, control_points = self.bsp.batch_compress(trajs)
|
| 578 |
# print("control point shape:", len(control_points),"*", len(control_points[0]), "*", len(control_points[0][0]))
|
| 579 |
|
| 580 |
# Stage 2: arrange to params
|
|
|
|
| 241 |
|
| 242 |
def fit(self, trajectory: torch.Tensor, cpu_cores: int = min(mp.cpu_count(), 16),
|
| 243 |
tol_ratio: float = 0.03, absolute_tol: Optional[float] = 0.01,
|
| 244 |
+
time_limit: int = 10, max_length: Optional[int] = None) -> Tuple[List[int], List[List[float]]]:
|
| 245 |
"""
|
| 246 |
Note: libero上设置为 absolute_tol, 0.01 为tokenization 误差。
|
| 247 |
Fit B-spline to trajectory using least squares.
|
|
|
|
| 252 |
tol_ratio: Tolerance ratio for relative fitting error (eps = d_range × tol_ratio)
|
| 253 |
absolute_tol: Absolute tolerance for fitting error. If set, overrides tol_ratio
|
| 254 |
time_limit: Time limit for MILP solver in seconds
|
| 255 |
+
max_length: Optional maximum length for the B-spline representation
|
| 256 |
Returns:
|
| 257 |
full_knots, control_points
|
| 258 |
"""
|
|
|
|
| 396 |
if self.use_gurobi:
|
| 397 |
solver_milp.close()
|
| 398 |
|
| 399 |
+
# 如果控制点数量超过max_length,重新采样knots并拟合
|
| 400 |
+
if max_length is not None and len(control_points[0]) > max_length:
|
| 401 |
+
full_knots_joint, control_points = [], []
|
| 402 |
+
n_internal = max(0, max_length - 2 * (self.degree + 1))
|
| 403 |
+
sampled_internal = np.linspace(t0, t_end, num=n_internal + 2)[1:-1].astype(int) if n_internal > 0 else []
|
| 404 |
+
full_knots_joint = np.concatenate([np.repeat(t0, self.degree + 1), sampled_internal, np.repeat(t_end, self.degree + 1)])
|
| 405 |
+
full_knots_grip = np.concatenate([np.repeat(t0, 1), sampled_internal, np.repeat(t_end, 1)])
|
| 406 |
+
|
| 407 |
+
control_points = []
|
| 408 |
+
for d in range(self.joint_dof):
|
| 409 |
+
control_points.append(make_lsq_spline(time_points, joint_traj[:, d], full_knots_joint, k=self.degree).c.tolist())
|
| 410 |
+
if self.gripper_dof > 0:
|
| 411 |
+
for gd in range(self.gripper_dof):
|
| 412 |
+
control_points.append([float(int(round(v))) for v in make_lsq_spline(time_points, gripper_traj[:, gd], full_knots_grip, k=0).c])
|
| 413 |
+
|
| 414 |
return full_knots_joint.tolist(), control_points
|
| 415 |
|
| 416 |
|
| 417 |
+
def batch_compress(self, batch_trajectory: torch.Tensor, max_length: int) -> Tuple[List[List[int]], List[List[List[float]]]]:
|
| 418 |
batch = batch_trajectory.shape[0]
|
| 419 |
cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
|
| 420 |
|
| 421 |
if self.is_multi_process and batch > 1:
|
| 422 |
with mp.Pool(processes=batch) as pool:
|
| 423 |
+
tasks = [(batch_trajectory[i], cpu_cores_per_fit, max_length) for i in range(batch)]
|
| 424 |
results = pool.starmap(self.fit, tasks)
|
| 425 |
else:
|
| 426 |
# Single-threaded processing
|
| 427 |
results = []
|
| 428 |
for i in batch_trajectory:
|
| 429 |
+
res = self.fit(i, cpu_cores=cpu_cores_per_fit, max_length=max_length)
|
| 430 |
results.append(res)
|
| 431 |
|
| 432 |
batch_knots = [res[0] for res in results]
|
|
|
|
| 589 |
"""
|
| 590 |
trajs = trajs.cpu() # 确保在 CPU 上处理
|
| 591 |
# Stage 1: adaptive bspline fit
|
| 592 |
+
knots, control_points = self.bsp.batch_compress(trajs, self.out_seq_len)
|
| 593 |
# print("control point shape:", len(control_points),"*", len(control_points[0]), "*", len(control_points[0][0]))
|
| 594 |
|
| 595 |
# Stage 2: arrange to params
|