Upload folder using huggingface_hub
Browse files
online_bspline_tokenizer.py
CHANGED
|
@@ -590,7 +590,6 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 590 |
if self.do_pad:
|
| 591 |
all_params = _pad_params(all_params, self.out_seq_len)
|
| 592 |
|
| 593 |
-
print("1. Decode continous has compeltete: (for debug stop error)")
|
| 594 |
return all_params
|
| 595 |
|
| 596 |
def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
|
|
@@ -629,7 +628,6 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 629 |
batch_size = all_params.shape[0] if isinstance(all_params, torch.Tensor) else len(all_params)
|
| 630 |
all_trajs = torch.zeros((batch_size, self.in_seq_len, self.num_dof), dtype=torch.float32)
|
| 631 |
|
| 632 |
-
print("1). Decode continous has compeltete: (for debug stop error)")
|
| 633 |
return all_trajs
|
| 634 |
|
| 635 |
@torch.no_grad()
|
|
@@ -653,7 +651,6 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 653 |
|
| 654 |
# Rearrange: [batch, out_seq_len, num_t_dof] -> [batch, out_seq_len * num_t_dof]
|
| 655 |
tokens = tokens.reshape(batch_size, -1).long().to(self.device)
|
| 656 |
-
print("1). Decode discrete has compeltete: (for debug stop error)")
|
| 657 |
return tokens
|
| 658 |
|
| 659 |
@torch.no_grad()
|
|
@@ -677,7 +674,6 @@ class BestTokenizer(torch.nn.Module, ProcessorMixin):
|
|
| 677 |
target_length = self.in_seq_len if target_length is None else target_length
|
| 678 |
all_trajs = self.decode_continuous(normalized_all_params, target_length)
|
| 679 |
|
| 680 |
-
print("2. Decode discrete has compeltete: (for debug stop error)")
|
| 681 |
return all_trajs.to(self.device)
|
| 682 |
|
| 683 |
if __name__ == "__main__":
|
|
|
|
| 590 |
if self.do_pad:
|
| 591 |
all_params = _pad_params(all_params, self.out_seq_len)
|
| 592 |
|
|
|
|
| 593 |
return all_params
|
| 594 |
|
| 595 |
def decode_continuous(self, all_params: Union[torch.Tensor, List[torch.Tensor]],
|
|
|
|
| 628 |
batch_size = all_params.shape[0] if isinstance(all_params, torch.Tensor) else len(all_params)
|
| 629 |
all_trajs = torch.zeros((batch_size, self.in_seq_len, self.num_dof), dtype=torch.float32)
|
| 630 |
|
|
|
|
| 631 |
return all_trajs
|
| 632 |
|
| 633 |
@torch.no_grad()
|
|
|
|
| 651 |
|
| 652 |
# Rearrange: [batch, out_seq_len, num_t_dof] -> [batch, out_seq_len * num_t_dof]
|
| 653 |
tokens = tokens.reshape(batch_size, -1).long().to(self.device)
|
|
|
|
| 654 |
return tokens
|
| 655 |
|
| 656 |
@torch.no_grad()
|
|
|
|
| 674 |
target_length = self.in_seq_len if target_length is None else target_length
|
| 675 |
all_trajs = self.decode_continuous(normalized_all_params, target_length)
|
| 676 |
|
|
|
|
| 677 |
return all_trajs.to(self.device)
|
| 678 |
|
| 679 |
if __name__ == "__main__":
|