File size: 14,222 Bytes
991941e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 | import logging
import numpy as np
import sentencepiece
from transformers import AutoProcessor
import etils.epath as epath
import openpi.shared.download as download
class PaligemmaTokenizer:
    """Tokenize text prompts with the PaliGemma SentencePiece model.

    Every prompt is encoded to a fixed-length sequence of ``max_len`` token ids
    together with a boolean mask that marks the real (non-padding) positions.
    """

    def __init__(self, max_len: int = 48):
        """Load the SentencePiece model and remember the target length.

        The model file is looked up in this order: a local ``assets`` copy, a
        fixed absolute cluster path, and finally an anonymous download from
        the ``gs://big_vision`` bucket.

        Args:
            max_len: Length every tokenized prompt is padded or truncated to.
        """
        self._max_len = max_len

        local_path = epath.Path("assets/paligemma_tokenizer.model")
        hf_path = epath.Path("/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/VLA-Humanoid/paligemma-3b-pt-224/tokenizer.model")
        if local_path.exists():
            path = local_path
        elif hf_path.exists():
            path = hf_path
        else:
            path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})

        with path.open("rb") as f:
            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

    def tokenize(self, prompt: str) -> tuple[np.ndarray, np.ndarray]:
        """Encode ``prompt`` into a fixed-length token array and validity mask.

        Leading/trailing whitespace is stripped and underscores and newlines
        are replaced by spaces before encoding. A separately encoded ``"\\n"``
        is appended as the "start of answer" token.

        Args:
            prompt: Free-form text prompt.

        Returns:
            ``(tokens, mask)``, both of length ``max_len``. ``mask`` is True on
            real tokens and False on padding. Sequences longer than
            ``max_len`` are truncated and a warning is logged.
        """
        cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
        # tokenize "\n" separately as the "start of answer" token
        tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
        tokens_len = len(tokens)

        if tokens_len < self._max_len:
            pad_len = self._max_len - tokens_len
            mask = [True] * tokens_len + [False] * pad_len
            # Pad ids with an integer 0 (not a boolean) so the list stays homogeneous.
            tokens = tokens + [0] * pad_len
        else:
            if tokens_len > self._max_len:
                logging.warning(
                    "Token length (%s) exceeds max length (%s), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently.",
                    tokens_len,
                    self._max_len,
                )
                tokens = tokens[: self._max_len]
            mask = [True] * self._max_len

        return np.asarray(tokens), np.asarray(mask)
class FASTTokenizer:
def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
self._max_len = max_len
# Download base PaliGemma tokenizer
local_path = epath.Path("assets/paligemma_tokenizer.model")
hf_path = epath.Path("/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/VLA-Humanoid/paligemma-3b-pt-224/tokenizer.model")
if local_path.exists():
path = local_path
elif hf_path.exists():
path = hf_path
else:
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
# Instantiate FAST tokenizer - check for local path first
local_fast_path = epath.Path("fast")
if local_fast_path.exists():
fast_tokenizer_path = str(local_fast_path)
self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None,
dont_pad: bool = False,
dont_loss: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
# Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
action_tokens = self._fast_tokenizer(actions[None])[0]
action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
# Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
postfix_tokens = (
self._paligemma_tokenizer.encode("Action: ")
+ action_tokens_in_pg.tolist()
+ self._paligemma_tokenizer.encode("|")
)
else:
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
if dont_loss:
loss_mask = [False] * len(prefix_tokens) + [False] * len(postfix_tokens) # no loss on prefix or postfix
else:
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
# When padding is not desired
if dont_pad:
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
print(f"WARNING: No `Action: ` found in decoded tokens: {decoded_tokens}, so returning zeros")
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
return self._fast_tokenizer.decode(
[action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
)[0]
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
class FASTTokenizerRicl:
def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast", action_horizon: int = 10, action_dim: int = 8):
self._max_len = max_len
self._action_horizon = action_horizon
self._action_dim = action_dim
# Download base PaliGemma tokenizer
local_path = epath.Path("assets/paligemma_tokenizer.model")
hf_path = epath.Path("/projects/extern/kisski/kisski-spath/dir.project/VLA_Groot/in_context_learning/VLA-Humanoid/paligemma-3b-pt-224/tokenizer.model")
if local_path.exists():
path = local_path
elif hf_path.exists():
path = hf_path
else:
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
# Instantiate FAST tokenizer - check for local path first
local_fast_path = epath.Path("fast")
if local_fast_path.exists():
fast_tokenizer_path = str(local_fast_path)
self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None,
dont_pad: bool = False,
dont_loss: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
# Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
assert actions.shape == (self._action_horizon, self._action_dim), f"{actions.shape=}"
action_tokens = self._fast_tokenizer(actions[None])[0]
action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
# Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
postfix_tokens = (
self._paligemma_tokenizer.encode("Action: ")
+ action_tokens_in_pg.tolist()
+ self._paligemma_tokenizer.encode("|")
)
else:
postfix_tokens = []
# always pad prefix tokens to 1/2 the max length
assert self._max_len % 2 == 0, "max_len must be divisible by 2 to pad prefix tokens to 1/2 the max length and postfix tokens to the rest"
if len(prefix_tokens) < self._max_len // 2:
prefix_padding = [False] * (self._max_len // 2 - len(prefix_tokens))
else:
raise ValueError(f"Prefix tokens length ({len(prefix_tokens)}) exceeds 1/2 the max length ({self._max_len // 2})! Increase the `max_token_len` in your model config.")
# pad postfix tokens if not dont_pad
if dont_pad:
postfix_padding = []
else:
postfix_padding = [False] * (self._max_len - len(prefix_tokens) - len(prefix_padding) - len(postfix_tokens))
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens_len = len(prefix_tokens) + len(prefix_padding) + len(postfix_tokens) + len(postfix_padding)
if not dont_pad:
assert tokens_len == self._max_len
token_mask = [True] * len(prefix_tokens) + [False] * len(prefix_padding) + [True] * len(postfix_tokens) + [False] * len(postfix_padding)
ar_mask = [0] * len(prefix_tokens) + [False] * len(prefix_padding) + [1] * len(postfix_tokens) + [False] * len(postfix_padding)
if dont_loss:
loss_mask = [False] * tokens_len # no loss on prefix or postfix
else:
loss_mask = [False] * len(prefix_tokens) + [False] * len(prefix_padding) + [True] * len(postfix_tokens) + [False] * len(postfix_padding) # Loss on postfix_tokens only
# pad prefix and postfix tokens
prefix_tokens = prefix_tokens + prefix_padding
postfix_tokens = postfix_tokens + postfix_padding
if len(postfix_tokens) == 0:
# happens at inference time when actions are not provided and dont_pad is True
postfix_tokens = None
else:
postfix_tokens = np.asarray(postfix_tokens)
return np.asarray(prefix_tokens), postfix_tokens, np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
assert action_horizon == self._action_horizon and action_dim == self._action_dim, f"{action_horizon=}, {action_dim=}, {self._action_horizon=}, {self._action_dim=}"
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
print(f"WARNING: No `Action: ` found in decoded tokens: {decoded_tokens}, so returning zeros")
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
print(f'decoded_tokens: {decoded_tokens}')
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
print(f'raw_action_tokens: {raw_action_tokens}')
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
print(f'action_tokens: {action_tokens}')
outputs = self._fast_tokenizer.decode(
[action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
)
assert outputs.shape == (1, action_horizon, action_dim), f"{outputs.shape=}"
outputs = outputs[0]
print(f'outputs before normalization: {outputs}')
return outputs
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
|