Upload folder using huggingface_hub
Browse files- online_bspline_tokenizer.py +18 -5
online_bspline_tokenizer.py
CHANGED
|
@@ -109,7 +109,7 @@ def _params_to_knots_and_control_points(
|
|
| 109 |
batch_control_points.append(control_points)
|
| 110 |
|
| 111 |
# print("batch knots:", batch_knots)
|
| 112 |
-
# print("batch control points:", batch_control_points)
|
| 113 |
|
| 114 |
return batch_knots, batch_control_points
|
| 115 |
|
|
@@ -593,7 +593,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 593 |
return all_params
|
| 594 |
|
| 595 |
def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
|
| 596 |
-
target_length: Optional[int] = None) -> torch.Tensor:
|
| 597 |
"""
|
| 598 |
Decode continuous normalized parameters to trajectories.
|
| 599 |
|
|
@@ -621,6 +621,19 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 621 |
knots, control_points = _params_to_knots_and_control_points(
|
| 622 |
all_params, gripper_dof=self.gripper_dof, degree=self.bsp.degree
|
| 623 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
# Stage 1: B-spline decode (with per-sample exception handling)
|
| 625 |
all_trajs = self.bsp.batch_decompress(knots, control_points, self.in_seq_len)
|
| 626 |
except Exception as e:
|
|
@@ -654,14 +667,14 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 654 |
return tokens
|
| 655 |
|
| 656 |
@torch.no_grad()
|
| 657 |
-
def decode_discrete(self, tokens: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
|
| 658 |
"""
|
| 659 |
Decode discrete tokens to trajectories.
|
| 660 |
|
| 661 |
Args:
|
| 662 |
tokens: Discrete tokens [batch, out_seq_len * num_t_dof]
|
| 663 |
target_length: Target trajectory length (default: self.seq_length)
|
| 664 |
-
|
| 665 |
Returns:
|
| 666 |
Reconstructed trajectories [batch, seq_len, num_dof]
|
| 667 |
"""
|
|
@@ -672,7 +685,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 672 |
normalized_all_params = _discrete_to_continuous(tokens, torch.tensor(-1.0), torch.tensor(1.0), self.vocab_size)
|
| 673 |
|
| 674 |
target_length = self.in_seq_len if target_length is None else target_length
|
| 675 |
-
all_trajs = self.decode_continuous(normalized_all_params, target_length)
|
| 676 |
|
| 677 |
return all_trajs.to(self.device)
|
| 678 |
|
|
|
|
| 109 |
batch_control_points.append(control_points)
|
| 110 |
|
| 111 |
# print("batch knots:", batch_knots)
|
| 112 |
+
# print("batch control points:", batch_control_points) # [batch, dof, num_ctrl]
|
| 113 |
|
| 114 |
return batch_knots, batch_control_points
|
| 115 |
|
|
|
|
| 593 |
return all_params
|
| 594 |
|
| 595 |
def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
|
| 596 |
+
target_length: Optional[int] = None, init_pos: torch.Tensor = None) -> torch.Tensor:
|
| 597 |
"""
|
| 598 |
Decode continuous normalized parameters to trajectories.
|
| 599 |
|
|
|
|
| 621 |
knots, control_points = _params_to_knots_and_control_points(
|
| 622 |
all_params, gripper_dof=self.gripper_dof, degree=self.bsp.degree
|
| 623 |
)
|
| 624 |
+
# Stage 1.5: handle init_pos — replace the first control point of each joint DOF
|
| 625 |
+
if init_pos is not None:
|
| 626 |
+
init_pos_np = init_pos.cpu().numpy() # [batch, num_dof]
|
| 627 |
+
if init_pos_np.shape[0] != len(control_points):
|
| 628 |
+
raise ValueError("init_pos batch size mismatch with decoded params")
|
| 629 |
+
if init_pos_np.shape[1] < self.joint_dof:
|
| 630 |
+
raise ValueError("init_pos num_dof smaller than joint_dof")
|
| 631 |
+
|
| 632 |
+
for b_idx in range(len(control_points)):
|
| 633 |
+
for dof_idx in range(self.joint_dof):
|
| 634 |
+
if control_points[b_idx][dof_idx]:
|
| 635 |
+
control_points[b_idx][dof_idx][0] = float(init_pos_np[b_idx, dof_idx])
|
| 636 |
+
|
| 637 |
# Stage 1: B-spline decode (with per-sample exception handling)
|
| 638 |
all_trajs = self.bsp.batch_decompress(knots, control_points, self.in_seq_len)
|
| 639 |
except Exception as e:
|
|
|
|
| 667 |
return tokens
|
| 668 |
|
| 669 |
@torch.no_grad()
|
| 670 |
+
def decode_discrete(self, tokens: torch.Tensor, target_length: Optional[int] = None, init_pos: torch.Tensor = None) -> torch.Tensor:
|
| 671 |
"""
|
| 672 |
Decode discrete tokens to trajectories.
|
| 673 |
|
| 674 |
Args:
|
| 675 |
tokens: Discrete tokens [batch, out_seq_len * num_t_dof]
|
| 676 |
target_length: Target trajectory length (default: self.in_seq_len)
|
| 677 |
+
init_pos: Initial position tensor [batch, num_dof] (default: None)
|
| 678 |
Returns:
|
| 679 |
Reconstructed trajectories [batch, seq_len, num_dof]
|
| 680 |
"""
|
|
|
|
| 685 |
normalized_all_params = _discrete_to_continuous(tokens, torch.tensor(-1.0), torch.tensor(1.0), self.vocab_size)
|
| 686 |
|
| 687 |
target_length = self.in_seq_len if target_length is None else target_length
|
| 688 |
+
all_trajs = self.decode_continuous(normalized_all_params, target_length, init_pos)
|
| 689 |
|
| 690 |
return all_trajs.to(self.device)
|
| 691 |
|