Upload folder using huggingface_hub
Browse files
- online_bspline_tokenizer.py +53 -10
- processor_config.json +1 -0
online_bspline_tokenizer.py
CHANGED
|
@@ -24,6 +24,36 @@ from functools import wraps
|
|
| 24 |
from transformers.processing_utils import ProcessorMixin
|
| 25 |
from scipy.interpolate import BSpline, make_lsq_spline
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
# ============================================================================
|
| 28 |
# Utility Functions
|
| 29 |
# ============================================================================
|
|
@@ -216,13 +246,21 @@ class BestBSpline:
|
|
| 216 |
joint_dof: Number of joint DOFs
|
| 217 |
gripper_dof: Number of gripper DOFs (从后往前数)
|
| 218 |
check_step: Downsampling step for constraint checking acceleration
|
|
|
|
| 219 |
"""
|
| 220 |
-
def __init__(self, degree: int = 3, joint_dof: int = 6, gripper_dof: int = 1, check_step=1, use_gurobi: bool = False):
|
| 221 |
self.degree = degree
|
| 222 |
self.joint_dof = joint_dof
|
| 223 |
self.gripper_dof = gripper_dof
|
| 224 |
self.check_step = check_step # 降采样步长,用于加速
|
| 225 |
self.use_gurobi = use_gurobi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
def _extract_gripper_forced_knots(self, gripper_traj: np.ndarray, time_points: np.ndarray) -> List[int]:
|
| 228 |
"""
|
|
@@ -400,15 +438,18 @@ class BestBSpline:
|
|
| 400 |
def batch_compress(self, batch_trajectory: torch.Tensor):
|
| 401 |
batch = batch_trajectory.shape[0]
|
| 402 |
cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
|
| 403 |
-
# with mp.Pool(processes=batch) as pool:
|
| 404 |
-
# tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
|
| 405 |
-
# results = pool.starmap(self.fit, tasks)
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
results.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
|
| 413 |
batch_knots = [res[0] for res in results]
|
| 414 |
batch_control_points = [res[1] for res in results]
|
|
@@ -477,6 +518,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 477 |
degree: B-spline degree (default 3 = cubic)
|
| 478 |
gripper_zero_order: Use separate zero-order spline for gripper (True/False)
|
| 479 |
gripper_dof: Number of gripper DOFs (if gripper_zero_order is True)
|
|
|
|
| 480 |
|
| 481 |
Example:
|
| 482 |
>>> tokenizer = BestTokenizer(num_dof=7, num_basis=10, seq_len=50)
|
|
@@ -499,7 +541,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 499 |
|
| 500 |
def __init__(self, num_dof: int = 7, in_seq_len: int = 10, out_seq_len: int = 5,
|
| 501 |
vocab_size: int = 256, degree: int = 3, gripper_dof: int = 1,
|
| 502 |
-
do_pad: bool = True, use_gurobi: bool = False, device: str = "cuda"):
|
| 503 |
super().__init__()
|
| 504 |
self.in_seq_len = in_seq_len
|
| 505 |
self.out_seq_len = out_seq_len
|
|
@@ -516,6 +558,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 516 |
joint_dof=self.joint_dof,
|
| 517 |
gripper_dof=self.gripper_dof,
|
| 518 |
use_gurobi=use_gurobi,
|
|
|
|
| 519 |
)
|
| 520 |
|
| 521 |
# Initialize weight bounds for normalization
|
|
|
|
| 24 |
from transformers.processing_utils import ProcessorMixin
|
| 25 |
from scipy.interpolate import BSpline, make_lsq_spline
|
| 26 |
|
| 27 |
+
# ============================================================================
|
| 28 |
+
# Process Pool Manager (Singleton)
|
| 29 |
+
# ============================================================================
|
| 30 |
+
class ProcessPoolManager:
    """Process-wide singleton owning one shared multiprocessing pool.

    Creating and destroying a ``multiprocessing.Pool`` repeatedly is
    expensive; this manager lazily creates a single pool on first use and
    returns the same pool to every caller until :meth:`close_pool`.

    NOTE: the ``processes`` argument of :meth:`get_pool` is honored only
    when the pool is first created; later calls reuse the existing pool
    regardless of the value passed.
    """

    _instance = None   # the singleton instance
    _pool = None       # lazily created shared pool
    _lock = mp.Lock()  # guards _instance/_pool against concurrent access

    def __new__(cls):
        # Hold the lock while checking/creating so that two threads racing
        # on the first instantiation cannot each build an instance
        # (the original checked `_instance` outside the lock).
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def get_pool(self, processes=None):
        """Return the shared pool, creating it on first call.

        Args:
            processes: Worker count, used only if no pool exists yet;
                defaults to ``mp.cpu_count()``.

        Returns:
            The shared ``multiprocessing.Pool`` instance.
        """
        with self._lock:
            if self._pool is None:
                processes = processes or mp.cpu_count()
                # Assign on the class, not the instance: `self._pool = ...`
                # would create a shadowing instance attribute and leave the
                # class-level `_pool` permanently None.
                type(self)._pool = mp.Pool(processes=processes)
            return self._pool

    def close_pool(self):
        """Close and join the shared pool (no-op if none exists)."""
        with self._lock:
            if self._pool is not None:
                self._pool.close()
                self._pool.join()
                type(self)._pool = None
|
| 56 |
+
|
| 57 |
# ============================================================================
|
| 58 |
# Utility Functions
|
| 59 |
# ============================================================================
|
|
|
|
| 246 |
joint_dof: Number of joint DOFs
|
| 247 |
gripper_dof: Number of gripper DOFs (从后往前数)
|
| 248 |
check_step: Downsampling step for constraint checking acceleration
|
| 249 |
+
is_multi_process: Whether to use multiprocessing for batch compression (default: False)
|
| 250 |
"""
|
| 251 |
+
def __init__(self, degree: int = 3, joint_dof: int = 6, gripper_dof: int = 1, check_step: int = 1, use_gurobi: bool = False, is_multi_process: bool = False):
    """Configure the B-spline fitter.

    Args:
        degree: B-spline degree (default 3 = cubic).
        joint_dof: Number of joint DOFs.
        gripper_dof: Number of gripper DOFs (counted from the end).
        check_step: Downsampling step for constraint checking (speed-up).
        use_gurobi: Whether to use the Gurobi solver for fitting.
        is_multi_process: Whether batch compression uses a process pool
            (default: False).
    """
    self.degree = degree
    self.joint_dof = joint_dof
    self.gripper_dof = gripper_dof
    self.check_step = check_step  # downsampling step, used to speed up checks
    self.use_gurobi = use_gurobi
    self.is_multi_process = is_multi_process
    # Shared singleton pool manager; only allocated when multiprocessing is on.
    self.pool_manager = ProcessPoolManager() if is_multi_process else None
|
| 259 |
+
|
| 260 |
+
def __del__(self):
    """Best-effort cleanup of the shared process pool on destruction.

    NOTE(review): `pool_manager` is a process-wide singleton, so this also
    shuts the pool down for any other live instances — confirm intended.
    """
    try:
        # getattr guards: __del__ runs even if __init__ raised before
        # these attributes were set.
        if getattr(self, "is_multi_process", False) and getattr(self, "pool_manager", None) is not None:
            self.pool_manager.close_pool()
    except Exception:
        # Never propagate from __del__ — it may execute during interpreter
        # shutdown when modules are already torn down.
        pass
|
| 264 |
|
| 265 |
def _extract_gripper_forced_knots(self, gripper_traj: np.ndarray, time_points: np.ndarray) -> List[int]:
|
| 266 |
"""
|
|
|
|
| 438 |
def batch_compress(self, batch_trajectory: torch.Tensor):
|
| 439 |
batch = batch_trajectory.shape[0]
|
| 440 |
cpu_cores_per_fit = max(1, mp.cpu_count() // batch)
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
+
if self.is_multi_process and batch > 1:
|
| 443 |
+
# Use singleton process pool for parallel processing
|
| 444 |
+
pool = self.pool_manager.get_pool(processes=batch)
|
| 445 |
+
tasks = [(batch_trajectory[i], cpu_cores_per_fit) for i in range(batch)]
|
| 446 |
+
results = pool.starmap(self.fit, tasks)
|
| 447 |
+
else:
|
| 448 |
+
# Single-threaded processing
|
| 449 |
+
results = []
|
| 450 |
+
for i in batch_trajectory:
|
| 451 |
+
res = self.fit(i, cpu_cores=cpu_cores_per_fit)
|
| 452 |
+
results.append(res)
|
| 453 |
|
| 454 |
batch_knots = [res[0] for res in results]
|
| 455 |
batch_control_points = [res[1] for res in results]
|
|
|
|
| 518 |
degree: B-spline degree (default 3 = cubic)
|
| 519 |
gripper_zero_order: Use separate zero-order spline for gripper (True/False)
|
| 520 |
gripper_dof: Number of gripper DOFs (if gripper_zero_order is True)
|
| 521 |
+
is_multi_process: Whether to use multiprocessing for batch compression (default: False)
|
| 522 |
|
| 523 |
Example:
|
| 524 |
>>> tokenizer = BestTokenizer(num_dof=7, num_basis=10, seq_len=50)
|
|
|
|
| 541 |
|
| 542 |
def __init__(self, num_dof: int = 7, in_seq_len: int = 10, out_seq_len: int = 5,
|
| 543 |
vocab_size: int = 256, degree: int = 3, gripper_dof: int = 1,
|
| 544 |
+
do_pad: bool = True, use_gurobi: bool = False, is_multi_process: bool = False, device: str = "cuda"):
|
| 545 |
super().__init__()
|
| 546 |
self.in_seq_len = in_seq_len
|
| 547 |
self.out_seq_len = out_seq_len
|
|
|
|
| 558 |
joint_dof=self.joint_dof,
|
| 559 |
gripper_dof=self.gripper_dof,
|
| 560 |
use_gurobi=use_gurobi,
|
| 561 |
+
is_multi_process=is_multi_process,
|
| 562 |
)
|
| 563 |
|
| 564 |
# Initialize weight bounds for normalization
|
processor_config.json
CHANGED
|
@@ -11,5 +11,6 @@
|
|
| 11 |
"gripper_dof": 1,
|
| 12 |
"do_pad": true,
|
| 13 |
"use_gurobi": false,
|
|
|
|
| 14 |
"device": "cuda"
|
| 15 |
}
|
|
|
|
| 11 |
"gripper_dof": 1,
|
| 12 |
"do_pad": true,
|
| 13 |
"use_gurobi": false,
|
| 14 |
+
"is_multi_process": false,
|
| 15 |
"device": "cuda"
|
| 16 |
}
|