Upload checkpoint-4000/action_tokenizer.py with huggingface_hub
checkpoint-4000/action_tokenizer.py
ADDED
"""
action_tokenizer.py

Extension class; wraps a base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
"""
from typing import Union, Dict, Optional
import numpy as np
from transformers import PreTrainedTokenizerBase
from scipy.stats import norm
import torch

ACTION_TOKEN = '<ACTION{:05d}>'

class ActionTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: int = 256,
        min_action: float = -1.0,
        max_action: float = 1.0,
    ):
        self._vocab_size = num_bins
        self.tokenizer = tokenizer
        self.min_action, self.max_action = min_action, max_action
        self.bin_centers = np.linspace(min_action, max_action, num_bins)

        # add special action tokens to the language tokenizer
        token_list = [ACTION_TOKEN.format(i) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} ACTION TOKENS, tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.action_token_begin_idx = self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])

    def __call__(self, action: np.ndarray) -> np.ndarray:
        """Discretize continuous actions to tokens.
        action: np.ndarray, (n, 7), continuous actions in [min_action, max_action].
        return: np.ndarray, (n, 7), token strings.
        """
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        ids = np.digitize(action, self.bin_centers, right=True)  # in [0, num_bins - 1]
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.
        action_token_id: np.ndarray, (n, 7), token ids.
        return: np.ndarray, (n, 7), continuous actions.
        """
        ids = action_token_id - self.action_token_begin_idx
        ids = np.clip(ids, a_min=0, a_max=self._vocab_size - 1)
        return self.bin_centers[ids]

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

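# A minimal usage sketch of the uniform-bin ActionTokenizer above. The "gpt2"
# checkpoint is only a placeholder; any HuggingFace tokenizer supporting
# add_tokens should work the same way.
def _example_action_tokenizer():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    action_tokenizer = ActionTokenizer(tok, num_bins=256)

    action = np.random.uniform(-1.0, 1.0, size=(2, 7))
    tokens = action_tokenizer(action)  # (2, 7) array of '<ACTIONxxxxx>' strings
    token_ids = np.array([[tok.convert_tokens_to_ids(t) for t in row] for row in tokens])
    recovered = action_tokenizer.decode_token_ids_to_actions(token_ids)
    # quantization error is at most one bin width, 2 / 255 for 256 bins on [-1, 1]
    assert np.abs(recovered - action).max() <= 2.0 / 255 + 1e-8
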
class TranslationTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        bin_policy: Optional[Dict] = None,
        use_spherical: bool = True,
    ):
        self.tokenizer = tokenizer
        self.num_theta_bins = num_bins["theta_bins"]
        self.num_phi_bins = num_bins["phi_bins"]
        self.num_r_bins = num_bins["r_bins"]
        self.use_spherical = use_spherical

        # for indexing
        self.NP = self.num_phi_bins * self.num_r_bins

        # add special action tokens to the language tokenizer
        self._vocab_size = self.num_theta_bins * self.num_phi_bins * self.num_r_bins
        token_list = [ACTION_TOKEN.format(i) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} TRANSLATION TOKENS, tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])
        self.set_bins(bin_policy)

    def set_bins(self, bin_policy):
        self.theta_bins = np.array(bin_policy["theta_bins"])
        self.phi_bins = np.array(bin_policy["phi_bins"])
        self.r_bins = np.array(bin_policy["r_bins"])

    def cartesian_to_spherical(self, x, y, z):
        theta = np.arctan2(np.sqrt(x**2 + y**2), z)  # polar angle
        phi = np.arctan2(y, x)  # azimuthal angle
        r = np.sqrt(x**2 + y**2 + z**2)
        return theta, phi, r

    def spherical_to_cartesian(self, theta, phi, r):
        x = r * np.sin(theta) * np.cos(phi)
        y = r * np.sin(theta) * np.sin(phi)
        z = r * np.cos(theta)
        return x, y, z

    def __call__(self, action: np.ndarray) -> np.ndarray:
        """Discretize continuous actions to tokens.
        action: np.ndarray, (n, 3), continuous actions in Cartesian or spherical coordinates.
        return: np.ndarray, (n,), token strings.
        """
        if self.use_spherical:
            theta, phi, r = self.cartesian_to_spherical(action[:, 0], action[:, 1], action[:, 2])
        else:
            theta, phi, r = action[:, 0], action[:, 1], action[:, 2]

        # digitize against interior boundaries, yielding indices in [0, num_bins - 1]
        disc_theta = np.digitize(theta, self.theta_bins[1:-1])
        disc_phi = np.digitize(phi, self.phi_bins[1:-1])
        disc_r = np.digitize(r, self.r_bins[1:-1])
        ids = disc_theta * self.NP + disc_phi * self.num_r_bins + disc_r
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.
        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 3), continuous actions.
        """
        action_token_id = np.clip(action_token_id, self.token_start_idx, self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        disc_theta, disc_phi, disc_r = ids // self.NP, (ids % self.NP) // self.num_r_bins, ids % self.num_r_bins

        theta = 0.5 * (self.theta_bins[disc_theta] + self.theta_bins[disc_theta + 1])
        phi = 0.5 * (self.phi_bins[disc_phi] + self.phi_bins[disc_phi + 1])
        r = 0.5 * (self.r_bins[disc_r] + self.r_bins[disc_r + 1])

        # clip to [-1, 1]: the spherical action space is the circumscribed sphere of the
        # Cartesian action cube, so decoded points can fall outside the cube
        x, y, z = self.spherical_to_cartesian(theta, phi, r) if self.use_spherical else (theta, phi, r)
        x, y, z = np.clip([x, y, z], -1, 1)
        return np.stack((x, y, z), axis=1)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

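# A minimal sketch of the translation round trip under uniform bins; the bin counts
# and the "gpt2" tokenizer below are placeholders. A Cartesian delta is converted to
# spherical (theta, phi, r), each component is binned, and the three bin indices are
# packed into one flat token id: id = theta_idx * (P * R) + phi_idx * R + r_idx.
def _example_translation_tokenizer():
    from transformers import AutoTokenizer

    num_bins = {"theta_bins": 8, "phi_bins": 16, "r_bins": 8}
    bin_policy = {
        "theta_bins": np.linspace(0.0, np.pi, num_bins["theta_bins"] + 1),
        "phi_bins": np.linspace(-np.pi, np.pi, num_bins["phi_bins"] + 1),
        "r_bins": np.linspace(0.0, np.sqrt(3), num_bins["r_bins"] + 1),
    }
    tok = AutoTokenizer.from_pretrained("gpt2")
    trans_tok = TranslationTokenizer(tok, num_bins, bin_policy)

    action = np.array([[0.2, -0.1, 0.3]])
    tokens = trans_tok(action)  # (1,) array with one '<ACTIONxxxxx>' string
    token_ids = np.array([tok.convert_tokens_to_ids(t) for t in tokens])
    decoded = trans_tok.decode_token_ids_to_actions(token_ids)  # (1, 3) bin-center estimate
    assert decoded.shape == (1, 3)
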
class RotationTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        bin_policy: Optional[Dict] = None,
        array_begin_idx=None,
    ):
        self.tokenizer = tokenizer
        self.num_roll_bins = num_bins["roll_bins"]  # M
        self.num_pitch_bins = num_bins["pitch_bins"]  # N
        self.num_yaw_bins = num_bins["yaw_bins"]  # P
        self.array_begin_idx = array_begin_idx

        # for indexing
        self.NP = self.num_pitch_bins * self.num_yaw_bins

        # add special action tokens to the language tokenizer
        self._vocab_size = self.num_roll_bins * self.num_pitch_bins * self.num_yaw_bins
        token_list = [ACTION_TOKEN.format(i + self.array_begin_idx) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} ROTATION TOKENS to tokenizer, tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])
        self.set_bins(bin_policy)

    def set_bins(self, bin_policy):
        self.roll_bins = np.array(bin_policy["roll_bins"])
        self.pitch_bins = np.array(bin_policy["pitch_bins"])
        self.yaw_bins = np.array(bin_policy["yaw_bins"])

    def __call__(self, action: np.ndarray) -> np.ndarray:
        """Discretize continuous actions to tokens.
        action: np.ndarray, (n, 3), continuous roll/pitch/yaw actions.
        return: np.ndarray, (n,), token strings.
        """
        roll, pitch, yaw = action[:, 0], action[:, 1], action[:, 2]
        disc_roll = np.clip(np.digitize(roll, self.roll_bins) - 1, 0, self.num_roll_bins - 1)
        disc_pitch = np.clip(np.digitize(pitch, self.pitch_bins) - 1, 0, self.num_pitch_bins - 1)
        disc_yaw = np.clip(np.digitize(yaw, self.yaw_bins) - 1, 0, self.num_yaw_bins - 1)

        ids = disc_roll * self.NP + disc_pitch * self.num_yaw_bins + disc_yaw
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: Union[np.int64, np.ndarray]) -> np.ndarray:
        """Decode token ids to continuous actions.
        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 3), continuous actions.
        """
        action_token_id = np.clip(action_token_id, a_min=self.token_start_idx, a_max=self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        disc_roll, disc_pitch, disc_yaw = ids // self.NP, (ids % self.NP) // self.num_yaw_bins, ids % self.num_yaw_bins

        roll = 0.5 * (self.roll_bins[disc_roll] + self.roll_bins[disc_roll + 1])
        pitch = 0.5 * (self.pitch_bins[disc_pitch] + self.pitch_bins[disc_pitch + 1])
        yaw = 0.5 * (self.yaw_bins[disc_yaw] + self.yaw_bins[disc_yaw + 1])
        return np.stack((roll, pitch, yaw), axis=1)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

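# A small sketch of the rotation binning convention: np.digitize against the full
# boundary array returns indices in [0, num_bins] (0 or num_bins for out-of-range
# values), so subtracting 1 and clipping yields a bin index in [0, num_bins - 1].
# This is equivalent to digitizing against the interior boundaries, as
# TranslationTokenizer does above. The bin count here is illustrative.
def _example_rotation_binning():
    bins = np.linspace(-1.0, 1.0, 9)  # 8 bins, 9 boundaries
    vals = np.array([-1.5, -1.0, 0.0, 0.99, 1.5])
    idx = np.clip(np.digitize(vals, bins) - 1, 0, 7)
    assert idx.tolist() == [0, 0, 4, 7, 7]
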
class GripperTokenzier:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: int = 2,
        array_begin_idx=None,
    ) -> None:
        self.tokenizer = tokenizer
        self.num_bins = num_bins
        self.array_begin_idx = array_begin_idx
        token_list = [ACTION_TOKEN.format(i + self.array_begin_idx) for i in range(self.num_bins)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} GRIPPER TOKENS to tokenizer, tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])

    def __call__(self, action: np.ndarray) -> np.ndarray:
        """Discretize continuous actions to tokens.
        action: np.ndarray, (n,), continuous gripper actions, binarized at 0.5.
        return: np.ndarray, (n,), token strings.
        """
        ids = np.where(action >= 0.5, 1, 0)
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.
        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 1), continuous actions.
        """
        action_token_id = np.clip(action_token_id, self.token_start_idx, self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        actions = np.where(ids == 0, 0., 1.)
        return actions[:, None]

    @property
    def vocab_size(self) -> int:
        return self.num_bins

class SpatialActionTokenizer:
    range_bins = {
        "translation": {
            "theta_bins": (0.0, np.pi),
            "phi_bins": (-np.pi, np.pi),
            "r_bins": (0.0, np.sqrt(3)),
        },
        "rotation": {
            "roll_bins": (-1.0, 1.0),
            "pitch_bins": (-1.0, 1.0),
            "yaw_bins": (-1.0, 1.0),
        },
    }

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        gs_params: Dict = None,
        bin_policy: Dict = None,
        use_spherical: bool = True,
        min_sigma: float = 0.0,
        min_action: float = -1.0,
        max_action: float = 1.0,
    ):
        """Use bin_policy if provided; otherwise calculate it from gs_params, or fall back to uniform bin grids.
        gs_params: Optional[Dict],
        bin_policy: Optional[Dict],
        """
        self.tokenizer = tokenizer
        self.min_action, self.max_action = min_action, max_action
        self.num_bins = num_bins
        self.min_sigma = min_sigma

        # set bin policy
        self.bin_policy = bin_policy if bin_policy else self.get_bin_policy(gs_params, self.min_sigma)
        self.translation_tokenizer = TranslationTokenizer(
            self.tokenizer,
            self.num_bins["translation"],
            self.bin_policy["translation"],
            use_spherical=use_spherical,
        )

        self.rotation_tokenizer = RotationTokenizer(
            self.tokenizer,
            self.num_bins["rotation"],
            self.bin_policy["rotation"],
            array_begin_idx=self.translation_tokenizer.vocab_size,
        )

        self.gripper_tokenizer = GripperTokenzier(
            self.tokenizer,
            self.num_bins["gripper"],
            array_begin_idx=self.translation_tokenizer.vocab_size + self.rotation_tokenizer.vocab_size,
        )
        self._vocab_size = self.translation_tokenizer.vocab_size + self.rotation_tokenizer.vocab_size + self.gripper_tokenizer.vocab_size

    def __call__(self, action: np.ndarray) -> np.ndarray:
        """Discretize continuous actions to tokens.
        action: np.ndarray, (n, 7), continuous actions in Cartesian coordinates.
        return: np.ndarray, (n, 3), token strings.
        """
        if len(action.shape) == 1:
            assert action.shape[0] == 7, f"action dim mismatch, got action shape: {action.shape}"
            action = action.reshape(1, 7)
        assert action.shape[1] == 7, f"action dim mismatch, got action shape: {action.shape}"

        action = np.clip(action, a_min=self.min_action, a_max=self.max_action)
        trans_tokens = self.translation_tokenizer(action[:, :3])  # (n,)
        rot_tokens = self.rotation_tokenizer(action[:, 3:6])  # (n,)
        grip_tokens = self.gripper_tokenizer(action[:, 6])  # (n,)
        return np.stack((trans_tokens, rot_tokens, grip_tokens), axis=1)  # (n, 3)

    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.
        action_token_ids: np.ndarray, (n, 3), token ids.
        return: np.ndarray, (n, 7), continuous actions.
        """
        if len(action_token_ids.shape) == 1:
            assert action_token_ids.shape[0] == 3, f"action token id number mismatch, need 3, got {action_token_ids.shape[0]}"
            action_token_ids = action_token_ids.reshape(1, 3)
        assert action_token_ids.shape[1] == 3, f"token id number mismatch, need 3, got {action_token_ids.shape[1]}"

        trans_action = self.translation_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 0])  # (n, 3)
        rot_action = self.rotation_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 1])  # (n, 3)
        grip_action = self.gripper_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 2])  # (n, 1)
        return np.concatenate((trans_action, rot_action, grip_action), axis=1)  # (n, 7)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

    @property
    def action_token_begin_idx(self) -> int:
        return self.translation_tokenizer.token_start_idx

    def get_bin_policy(self, gs_params=None, min_sigma=0.0):
        bin_policy = {
            "translation": {"theta_bins": None, "phi_bins": None, "r_bins": None},
            "rotation": {"roll_bins": None, "pitch_bins": None, "yaw_bins": None},
        }
        if gs_params is None:
            for bin_type in self.range_bins.keys():
                for bin_key in self.range_bins[bin_type].keys():
                    bin_policy[bin_type][bin_key] = np.linspace(*self.range_bins[bin_type][bin_key], self.num_bins[bin_type][bin_key] + 1)
            print(f"use uniform bin grids ... \n{bin_policy}")
        else:
            for bin_type in self.range_bins.keys():
                for bin_key in self.range_bins[bin_type].keys():
                    mu = gs_params[bin_key.split("_")[0].lower()]["mu"]
                    sigma = max(gs_params[bin_key.split("_")[0].lower()]["sigma"], min_sigma)
                    # place boundaries at equally spaced quantiles of N(mu, sigma), so each
                    # bin carries roughly equal probability mass under the fitted Gaussian
                    bin_bound_prob = np.linspace(
                        norm.cdf(self.range_bins[bin_type][bin_key][0], loc=mu, scale=sigma),
                        norm.cdf(self.range_bins[bin_type][bin_key][1], loc=mu, scale=sigma),
                        self.num_bins[bin_type][bin_key] + 1,
                    )
                    bin_boundary = norm.ppf(bin_bound_prob, loc=mu, scale=sigma)
                    bin_policy[bin_type][bin_key] = np.clip(
                        bin_boundary,
                        self.range_bins[bin_type][bin_key][0],
                        self.range_bins[bin_type][bin_key][1],
                    ).tolist()  # for serialization
            print(f"calculate bin grids from gaussians \n{bin_policy}")
        return bin_policy

    def get_norm_meshgrid(self, bin_policy):
        grids = []
        policy = {k1: {k2: np.array(v2) for k2, v2 in v1.items()} for k1, v1 in bin_policy.items()}
        # NOTE: relies on the unified key order of range_bins (theta/phi/r, roll/pitch/yaw)
        for bin_type in self.range_bins.keys():
            bounds = []
            for bin_key in self.range_bins[bin_type].keys():
                minb, maxb = self.range_bins[bin_type][bin_key][0], self.range_bins[bin_type][bin_key][1]
                bin_boundary = policy[bin_type][bin_key]
                bin_center = (bin_boundary[:-1] + bin_boundary[1:]) / 2
                bin_center = np.concatenate([np.array([minb]), bin_center, np.array([maxb])])  # padding
                bin_center = (bin_center - minb) / (maxb - minb)  # normalize to [0, 1]
                bounds.append(bin_center)
            # generate grids; indexing="ij" keeps the (m, n, k) axis order consistent with
            # the C-order reshape of the embedding table in spatial_embedding_adaption
            grid_x, grid_y, grid_z = np.meshgrid(*bounds, indexing="ij")
            grids += [np.stack([grid_x, grid_y, grid_z], -1).reshape(-1, 3)]
        return grids[0], grids[1]  # (N, 3)

    def spatial_embedding_adaption(self, gs_params, embeddings: torch.nn.Embedding, min_sigma=0.0, adpt_feature=False):
        """Adapt spatial token embeddings to a new Gaussian bin policy.
        gs_params: Dict, Gaussian statistics that define the new bin policy.
        embeddings: torch.nn.Embedding with weight of shape (S, E).
        """
        from scipy.interpolate import griddata

        new_policy = self.get_bin_policy(gs_params, min_sigma=min_sigma)
        trans_grids0, rot_grids0 = self.get_norm_meshgrid(self.bin_policy)
        trans_grids1, rot_grids1 = self.get_norm_meshgrid(new_policy)

        print("overwrite bin policy and tokenizer bins ...")
        self.bin_policy = new_policy
        self.min_sigma = min_sigma
        self.translation_tokenizer.set_bins(new_policy["translation"])
        self.rotation_tokenizer.set_bins(new_policy["rotation"])

        if adpt_feature:
            emb_data = embeddings.weight.data  # (S, E)
            _, E = emb_data.shape

            # translation: interpolate old bin-center embeddings at the new bin centers
            m, n, k = (self.num_bins["translation"][k] for k in ["theta_bins", "phi_bins", "r_bins"])
            N = m * n * k
            trans_emb_data = emb_data[:N].reshape(m, n, k, -1).permute(3, 0, 1, 2)  # (E, m, n, k)
            pad_emb = torch.nn.functional.pad(trans_emb_data, (1, 1, 1, 1, 1, 1), "replicate").permute(1, 2, 3, 0).reshape(-1, E)
            adpt_trans_emb = griddata(trans_grids0, pad_emb.float().cpu().numpy(), trans_grids1, method='linear')
            adpt_trans_emb = adpt_trans_emb.reshape(m + 2, n + 2, k + 2, E)[1:-1, 1:-1, 1:-1]

            # rotation
            m1, n1, k1 = (self.num_bins["rotation"][k] for k in ["roll_bins", "pitch_bins", "yaw_bins"])
            M = m1 * n1 * k1
            rot_emb_data = emb_data[N:N + M].reshape(m1, n1, k1, -1).permute(3, 0, 1, 2)  # (E, m1, n1, k1)
            pad_emb = torch.nn.functional.pad(rot_emb_data, (1, 1, 1, 1, 1, 1), "replicate").permute(1, 2, 3, 0).reshape(-1, E)
            adpt_rot_emb = griddata(rot_grids0, pad_emb.float().cpu().numpy(), rot_grids1, method='linear')
            adpt_rot_emb = adpt_rot_emb.reshape(m1 + 2, n1 + 2, k1 + 2, E)[1:-1, 1:-1, 1:-1]

            # set data
            device, dtype = embeddings.weight.data.device, embeddings.weight.data.dtype
            embeddings.weight.data[:N] = torch.tensor(adpt_trans_emb.reshape(-1, E), device=device, dtype=dtype)
            embeddings.weight.data[N:N + M] = torch.tensor(adpt_rot_emb.reshape(-1, E), device=device, dtype=dtype)
            print("DONE! adapted spatial embeddings to the new gaussian distribution.")
            print(embeddings.weight.data)
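
# An end-to-end sketch: build a SpatialActionTokenizer with Gaussian-derived bins
# and round-trip one 7-DoF action. All bin counts and Gaussian parameters below are
# illustrative placeholders, as is the "gpt2" tokenizer.
def _example_spatial_action_tokenizer():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    num_bins = {
        "translation": {"theta_bins": 8, "phi_bins": 16, "r_bins": 8},
        "rotation": {"roll_bins": 8, "pitch_bins": 8, "yaw_bins": 8},
        "gripper": 2,
    }
    # per-dimension Gaussian statistics; keys follow the bin names (theta/phi/r, roll/pitch/yaw)
    gs_params = {
        "theta": {"mu": np.pi / 2, "sigma": 0.5},
        "phi": {"mu": 0.0, "sigma": 1.0},
        "r": {"mu": 0.3, "sigma": 0.2},
        "roll": {"mu": 0.0, "sigma": 0.3},
        "pitch": {"mu": 0.0, "sigma": 0.3},
        "yaw": {"mu": 0.0, "sigma": 0.3},
    }
    spatial_tok = SpatialActionTokenizer(tok, num_bins, gs_params=gs_params, min_sigma=0.1)

    action = np.array([0.2, -0.1, 0.3, 0.05, -0.02, 0.1, 1.0])  # x, y, z, roll, pitch, yaw, gripper
    tokens = spatial_tok(action)  # (1, 3): one translation, one rotation, one gripper token
    token_ids = np.vectorize(tok.convert_tokens_to_ids)(tokens)
    recovered = spatial_tok.decode_token_ids_to_actions(token_ids)  # (1, 7)
    assert recovered.shape == (1, 7)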