Luka-He commited on
Commit
ff67224
·
verified ·
1 Parent(s): fe36d9c

Upload folder using huggingface_hub

Browse files
online_bspline_tokenizer.py CHANGED
@@ -24,6 +24,36 @@ from functools import wraps
24
  from transformers.processing_utils import ProcessorMixin
25
  from scipy.interpolate import BSpline, make_lsq_spline
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  # ============================================================================
28
  # Utility Functions
29
  # ============================================================================
@@ -216,13 +246,21 @@ class BestBSpline:
216
  joint_dof: Number of joint DOFs
217
  gripper_dof: Number of gripper DOFs (从后往前数)
218
  check_step: Downsampling step for constraint checking acceleration
 
219
  """
220
- def __init__(self, degree: int = 3, joint_dof: int = 6, gripper_dof: int = 1, check_step=1, use_gurobi: bool = False):
221
  self.degree = degree
222
  self.joint_dof = joint_dof
223
  self.gripper_dof = gripper_dof
224
  self.check_step = check_step # 降采样步长,用于加速
225
  self.use_gurobi = use_gurobi
 
 
 
 
 
 
 
226
 
227
  def _extract_gripper_forced_knots(self, gripper_traj: np.ndarray, time_points: np.ndarray) -> List[int]:
228
  """
@@ -400,15 +438,18 @@ class BestBSpline:
400
  def batch_compress(self, batch_trajectory: torch.Tensor):
401
  batch = batch_trajectory.shape[0]
402
  cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
403
- # with mp.Pool(processes=batch) as pool:
404
- # tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
405
- # results = pool.starmap(self.fit, tasks)
406
 
407
- # NOTE:单线程测试是否会出现问题
408
- results = []
409
- for i in batch_trajectory:
410
- res = self.fit(i, cpu_cores=cpu_cores_per_fit)
411
- results.append(res)
 
 
 
 
 
 
412
 
413
  batch_knots = [res[0] for res in results]
414
  batch_control_points = [res[1] for res in results]
@@ -477,6 +518,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
477
  degree: B-spline degree (default 3 = cubic)
478
  gripper_zero_order: Use separate zero-order spline for gripper (True/False)
479
  gripper_dof: Number of gripper DOFs (if gripper_zero_order is True)
 
480
 
481
  Example:
482
  >>> tokenizer = BestTokenizer(num_dof=7, num_basis=10, seq_len=50)
@@ -499,7 +541,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
499
 
500
  def __init__(self, num_dof: int = 7, in_seq_len: int = 10, out_seq_len: int = 5,
501
  vocab_size: int = 256, degree: int = 3, gripper_dof: int = 1,
502
- do_pad: bool = True, use_gurobi: bool = False, device: str = "cuda"):
503
  super().__init__()
504
  self.in_seq_len = in_seq_len
505
  self.out_seq_len = out_seq_len
@@ -516,6 +558,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
516
  joint_dof=self.joint_dof,
517
  gripper_dof=self.gripper_dof,
518
  use_gurobi=use_gurobi,
 
519
  )
520
 
521
  # Initialize weight bounds for normalization
 
24
  from transformers.processing_utils import ProcessorMixin
25
  from scipy.interpolate import BSpline, make_lsq_spline
26
 
27
+ # ============================================================================
28
+ # Process Pool Manager (Singleton)
29
+ # ============================================================================
30
+ class ProcessPoolManager:
31
+ """Singleton process pool manager to avoid repeated pool creation/destruction."""
32
+ _instance = None
33
+ _pool = None
34
+ _lock = mp.Lock()
35
+
36
+ def __new__(cls):
37
+ if cls._instance is None:
38
+ cls._instance = super().__new__(cls)
39
+ return cls._instance
40
+
41
+ def get_pool(self, processes=None):
42
+ """Get or create the process pool."""
43
+ with self._lock:
44
+ if self._pool is None:
45
+ processes = processes or mp.cpu_count()
46
+ self._pool = mp.Pool(processes=processes)
47
+ return self._pool
48
+
49
+ def close_pool(self):
50
+ """Close and cleanup the process pool."""
51
+ with self._lock:
52
+ if self._pool is not None:
53
+ self._pool.close()
54
+ self._pool.join()
55
+ self._pool = None
56
+
57
  # ============================================================================
58
  # Utility Functions
59
  # ============================================================================
 
246
  joint_dof: Number of joint DOFs
247
  gripper_dof: Number of gripper DOFs (从后往前数)
248
  check_step: Downsampling step for constraint checking acceleration
249
+ is_multi_process: Whether to use multiprocessing for batch compression (default: False)
250
  """
251
+ def __init__(self, degree: int = 3, joint_dof: int = 6, gripper_dof: int = 1, check_step=1, use_gurobi: bool = False, is_multi_process: bool = False):
252
  self.degree = degree
253
  self.joint_dof = joint_dof
254
  self.gripper_dof = gripper_dof
255
  self.check_step = check_step # 降采样步长,用于加速
256
  self.use_gurobi = use_gurobi
257
+ self.is_multi_process = is_multi_process
258
+ self.pool_manager = ProcessPoolManager() if is_multi_process else None
259
+
260
+ def __del__(self):
261
+ """Cleanup process pool when object is destroyed."""
262
+ if self.is_multi_process and self.pool_manager is not None:
263
+ self.pool_manager.close_pool()
264
 
265
  def _extract_gripper_forced_knots(self, gripper_traj: np.ndarray, time_points: np.ndarray) -> List[int]:
266
  """
 
438
  def batch_compress(self, batch_trajectory: torch.Tensor):
439
  batch = batch_trajectory.shape[0]
440
  cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
 
 
 
441
 
442
+ if self.is_multi_process and batch > 1:
443
+ # Use singleton process pool for parallel processing
444
+ pool = self.pool_manager.get_pool(processes=batch)
445
+ tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
446
+ results = pool.starmap(self.fit, tasks)
447
+ else:
448
+ # Single-threaded processing
449
+ results = []
450
+ for i in batch_trajectory:
451
+ res = self.fit(i, cpu_cores=cpu_cores_per_fit)
452
+ results.append(res)
453
 
454
  batch_knots = [res[0] for res in results]
455
  batch_control_points = [res[1] for res in results]
 
518
  degree: B-spline degree (default 3 = cubic)
519
  gripper_zero_order: Use separate zero-order spline for gripper (True/False)
520
  gripper_dof: Number of gripper DOFs (if gripper_zero_order is True)
521
+ is_multi_process: Whether to use multiprocessing for batch compression (default: False)
522
 
523
  Example:
524
  >>> tokenizer = BestTokenizer(num_dof=7, num_basis=10, seq_len=50)
 
541
 
542
  def __init__(self, num_dof: int = 7, in_seq_len: int = 10, out_seq_len: int = 5,
543
  vocab_size: int = 256, degree: int = 3, gripper_dof: int = 1,
544
+ do_pad: bool = True, use_gurobi: bool = False, is_multi_process: bool = False, device: str = "cuda"):
545
  super().__init__()
546
  self.in_seq_len = in_seq_len
547
  self.out_seq_len = out_seq_len
 
558
  joint_dof=self.joint_dof,
559
  gripper_dof=self.gripper_dof,
560
  use_gurobi=use_gurobi,
561
+ is_multi_process=is_multi_process,
562
  )
563
 
564
  # Initialize weight bounds for normalization
processor_config.json CHANGED
@@ -11,5 +11,6 @@
11
  "gripper_dof": 1,
12
  "do_pad": true,
13
  "use_gurobi": false,
 
14
  "device": "cuda"
15
  }
 
11
  "gripper_dof": 1,
12
  "do_pad": true,
13
  "use_gurobi": false,
14
+ "is_multi_process": false,
15
  "device": "cuda"
16
  }