Luka-He commited on
Commit
ce80390
·
verified ·
1 Parent(s): f9096d0

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. online_bspline_tokenizer.py +18 -5
online_bspline_tokenizer.py CHANGED
@@ -109,7 +109,7 @@ def _params_to_knots_and_control_points(
109
  batch_control_points.append(control_points)
110
 
111
  # print("batch knots:", batch_knots)
112
- # print("batch control points:", batch_control_points)
113
 
114
  return batch_knots, batch_control_points
115
 
@@ -593,7 +593,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
593
  return all_params
594
 
595
  def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
596
- target_length: Optional[int] = None) -> torch.Tensor:
597
  """
598
  Decode continuous normalized parameters to trajectories.
599
 
@@ -621,6 +621,19 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
621
  knots, control_points = _params_to_knots_and_control_points(
622
  all_params, gripper_dof=self.gripper_dof, degree=self.bsp.degree
623
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  # Stage 1: B-spline decode (with per-sample exception handling)
625
  all_trajs = self.bsp.batch_decompress(knots, control_points, self.in_seq_len)
626
  except Exception as e:
@@ -654,14 +667,14 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
654
  return tokens
655
 
656
  @torch.no_grad()
657
- def decode_discrete(self, tokens: torch.Tensor, target_length: Optional[int] = None) -> torch.Tensor:
658
  """
659
  Decode discrete tokens to trajectories.
660
 
661
  Args:
662
  tokens: Discrete tokens [batch, out_seq_len * num_t_dof]
663
  target_length: Target trajectory length (default: self.seq_length)
664
-
665
  Returns:
666
  Reconstructed trajectories [batch, seq_len, num_dof]
667
  """
@@ -672,7 +685,7 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
672
  normalized_all_params = _discrete_to_continuous(tokens, torch.tensor(-1.0), torch.tensor(1.0), self.vocab_size)
673
 
674
  target_length = self.in_seq_len if target_length is None else target_length
675
- all_trajs = self.decode_continuous(normalized_all_params, target_length)
676
 
677
  return all_trajs.to(self.device)
678
 
 
109
  batch_control_points.append(control_points)
110
 
111
  # print("batch knots:", batch_knots)
112
+ # print("batch control points:", batch_control_points) # [batch, dof, num_ctrl]
113
 
114
  return batch_knots, batch_control_points
115
 
 
593
  return all_params
594
 
595
  def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
596
+ target_length: Optional[int] = None, init_pos: torch.Tensor = None) -> torch.Tensor:
597
  """
598
  Decode continuous normalized parameters to trajectories.
599
 
 
621
  knots, control_points = _params_to_knots_and_control_points(
622
  all_params, gripper_dof=self.gripper_dof, degree=self.bsp.degree
623
  )
624
+ # Stage 1.5: init_pos的处理,替换control points的第一个点
625
+ if init_pos is not None:
626
+ init_pos_np = init_pos.cpu().numpy() # [batch, num_dof]
627
+ if init_pos_np.shape[0] != len(control_points):
628
+ raise ValueError("init_pos batch size mismatch with decoded params")
629
+ if init_pos_np.shape[1] < self.joint_dof:
630
+ raise ValueError("init_pos num_dof smaller than joint_dof")
631
+
632
+ for b_idx in range(len(control_points)):
633
+ for dof_idx in range(self.joint_dof):
634
+ if control_points[b_idx][dof_idx]:
635
+ control_points[b_idx][dof_idx][0] = float(init_pos_np[b_idx, dof_idx])
636
+
637
  # Stage 1: B-spline decode (with per-sample exception handling)
638
  all_trajs = self.bsp.batch_decompress(knots, control_points, self.in_seq_len)
639
  except Exception as e:
 
667
  return tokens
668
 
669
  @torch.no_grad()
670
+ def decode_discrete(self, tokens: torch.Tensor, target_length: Optional[int] = None, init_pos: torch.Tensor = None) -> torch.Tensor:
671
  """
672
  Decode discrete tokens to trajectories.
673
 
674
  Args:
675
  tokens: Discrete tokens [batch, out_seq_len * num_t_dof]
676
  target_length: Target trajectory length (default: self.seq_length)
677
+ init_pos: Initial position tensor [batch, num_dof] (default: None)
678
  Returns:
679
  Reconstructed trajectories [batch, seq_len, num_dof]
680
  """
 
685
  normalized_all_params = _discrete_to_continuous(tokens, torch.tensor(-1.0), torch.tensor(1.0), self.vocab_size)
686
 
687
  target_length = self.in_seq_len if target_length is None else target_length
688
+ all_trajs = self.decode_continuous(normalized_all_params, target_length, init_pos)
689
 
690
  return all_trajs.to(self.device)
691