Luka-He commited on
Commit
a504ded
·
verified ·
1 Parent(s): ce80390

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. online_bspline_tokenizer.py +21 -6
online_bspline_tokenizer.py CHANGED
@@ -241,7 +241,7 @@ class BestBSpline:
241
 
242
  def fit(self, trajectory: torch.Tensor, cpu_cores: int = min(mp.cpu_count(), 16),
243
  tol_ratio: float = 0.03, absolute_tol: Optional[float] = 0.01,
244
- time_limit: int = 10) -> Tuple[List[int], List[List[float]]]:
245
  """
246
  Note: libero上设置为 absolute_tol, 0.01 为tokenization 误差。
247
  Fit B-spline to trajectory using least squares.
@@ -252,7 +252,7 @@ class BestBSpline:
252
  tol_ratio: Tolerance ratio for relative fitting error (eps = d_range × tol_ratio)
253
  absolute_tol: Absolute tolerance for fitting error. If set, overrides tol_ratio
254
  time_limit: Time limit for MILP solver in seconds
255
-
256
  Returns:
257
  full_knots, control_points
258
  """
@@ -396,22 +396,37 @@ class BestBSpline:
396
  if self.use_gurobi:
397
  solver_milp.close()
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  return full_knots_joint.tolist(), control_points
400
 
401
 
402
- def batch_compress(self, batch_trajectory: torch.Tensor):
403
  batch = batch_trajectory.shape[0]
404
  cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
405
 
406
  if self.is_multi_process and batch > 1:
407
  with mp.Pool(processes=batch) as pool:
408
- tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
409
  results = pool.starmap(self.fit, tasks)
410
  else:
411
  # Single-threaded processing
412
  results = []
413
  for i in batch_trajectory:
414
- res = self.fit(i, cpu_cores=cpu_cores_per_fit)
415
  results.append(res)
416
 
417
  batch_knots = [res[0] for res in results]
@@ -574,7 +589,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
574
  """
575
  trajs = trajs.cpu() # 确保在 CPU 上处理
576
  # Stage 1: adaptive bspline fit
577
- knots, control_points = self.bsp.batch_compress(trajs)
578
  # print("control point shape:", len(control_points),"*", len(control_points[0]), "*", len(control_points[0][0]))
579
 
580
  # Stage 2: arrange to params
 
241
 
242
  def fit(self, trajectory: torch.Tensor, cpu_cores: int = min(mp.cpu_count(), 16),
243
  tol_ratio: float = 0.03, absolute_tol: Optional[float] = 0.01,
244
+ time_limit: int = 10, max_length: Optional[int] = None) -> Tuple[List[int], List[List[float]]]:
245
  """
246
  Note: libero上设置为 absolute_tol, 0.01 为tokenization 误差。
247
  Fit B-spline to trajectory using least squares.
 
252
  tol_ratio: Tolerance ratio for relative fitting error (eps = d_range × tol_ratio)
253
  absolute_tol: Absolute tolerance for fitting error. If set, overrides tol_ratio
254
  time_limit: Time limit for MILP solver in seconds
255
+ max_length: Optional maximum length for the B-spline representation
256
  Returns:
257
  full_knots, control_points
258
  """
 
396
  if self.use_gurobi:
397
  solver_milp.close()
398
 
399
+ # 如果控制点数量超过max_length,重新采样knots并拟合
400
+ if max_length is not None and len(control_points[0]) > max_length:
401
+ full_knots_joint, control_points = [], []
402
+ n_internal = max(0, max_length - 2 * (self.degree + 1))
403
+ sampled_internal = np.linspace(t0, t_end, num=n_internal + 2)[1:-1].astype(int) if n_internal > 0 else []
404
+ full_knots_joint = np.concatenate([np.repeat(t0, self.degree + 1), sampled_internal, np.repeat(t_end, self.degree + 1)])
405
+ full_knots_grip = np.concatenate([np.repeat(t0, 1), sampled_internal, np.repeat(t_end, 1)])
406
+
407
+ control_points = []
408
+ for d in range(self.joint_dof):
409
+ control_points.append(make_lsq_spline(time_points, joint_traj[:, d], full_knots_joint, k=self.degree).c.tolist())
410
+ if self.gripper_dof > 0:
411
+ for gd in range(self.gripper_dof):
412
+ control_points.append([float(int(round(v))) for v in make_lsq_spline(time_points, gripper_traj[:, gd], full_knots_grip, k=0).c])
413
+
414
  return full_knots_joint.tolist(), control_points
415
 
416
 
417
+ def batch_compress(self, batch_trajectory: torch.Tensor, max_length: int) -> Tuple[List[List[int]], List[List[List[float]]]]:
418
  batch = batch_trajectory.shape[0]
419
  cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
420
 
421
  if self.is_multi_process and batch > 1:
422
  with mp.Pool(processes=batch) as pool:
423
+ tasks = [(batch_trajectory[i], cpu_cores_per_fit, max_length) for i in range(batch)]
424
  results = pool.starmap(self.fit, tasks)
425
  else:
426
  # Single-threaded processing
427
  results = []
428
  for i in batch_trajectory:
429
+ res = self.fit(i, cpu_cores=cpu_cores_per_fit, max_length=max_length)
430
  results.append(res)
431
 
432
  batch_knots = [res[0] for res in results]
 
589
  """
590
  trajs = trajs.cpu() # 确保在 CPU 上处理
591
  # Stage 1: adaptive bspline fit
592
+ knots, control_points = self.bsp.batch_compress(trajs, self.out_seq_len)
593
  # print("control point shape:", len(control_points),"*", len(control_points[0]), "*", len(control_points[0][0]))
594
 
595
  # Stage 2: arrange to params