"""
action_tokenizer.py

Extension class; wraps base LLM/VLM tokenizer with logic to discretize and tokenize continuous robot actions.
"""

from typing import List, Union, Dict, Optional

import numpy as np
from transformers import PreTrainedTokenizerBase
from scipy.stats import norm
import torch

ACTION_TOKEN = '<ACTION{:05d}>'

class ActionTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: int = 256,
        min_action: float = -1,
        max_action: float = 1,
    ):
        self._vocab_size = num_bins
        self.tokenizer = tokenizer
        self.min_action, self.max_action = min_action, max_action
        self.bin_centers = np.linspace(min_action, max_action, num_bins)

        # Register one new special token per bin with the base tokenizer.
        token_list = [ACTION_TOKEN.format(i) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} ACTION TOKENS; tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.action_token_begin_idx = self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])

    def __call__(self, action: np.ndarray) -> List[str]:
        """Discretize continuous actions to tokens.

        action: np.ndarray, (n, 7), continuous actions.
        return: np.ndarray, (n, 7), action token strings.
        """
        action = np.clip(action, a_min=float(self.min_action), a_max=float(self.max_action))
        ids = np.digitize(action, self.bin_centers, right=True)
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.

        action_token_id: np.ndarray, (n, 7), token ids.
        return: np.ndarray, (n, 7), continuous actions (bin centers).
        """
        ids = action_token_id - self.action_token_begin_idx
        ids = np.clip(ids, a_min=0, a_max=self._vocab_size - 1)
        return self.bin_centers[ids]

    @property
    def vocab_size(self) -> int:
        return self._vocab_size
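
# Usage sketch (illustrative only; assumes a HuggingFace tokenizer instance `base_tokenizer` created elsewhere):
#   action_tokenizer = ActionTokenizer(base_tokenizer, num_bins=256)
#   tokens = action_tokenizer(np.array([[0.1, -0.2, 0.3, 0.0, 0.0, 0.0, 1.0]]))   # (1, 7) token strings
#   token_ids = np.array(base_tokenizer.convert_tokens_to_ids(tokens[0].tolist()))
#   recovered = action_tokenizer.decode_token_ids_to_actions(token_ids)            # (7,) bin-center values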

class TranslationTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        bin_policy: Optional[Dict] = None,
        use_spherical: bool = True,
    ):
        self.tokenizer = tokenizer
        self.num_theta_bins = num_bins["theta_bins"]
        self.num_phi_bins = num_bins["phi_bins"]
        self.num_r_bins = num_bins["r_bins"]
        self.use_spherical = use_spherical

        # Stride used to flatten a (theta, phi, r) bin index into a single token index.
        self.NP = self.num_phi_bins * self.num_r_bins

        self._vocab_size = self.num_theta_bins * self.num_phi_bins * self.num_r_bins
        token_list = [ACTION_TOKEN.format(i) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} TRANSLATION TOKENS; tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])
        self.set_bins(bin_policy)

    def set_bins(self, bin_policy):
        self.theta_bins = np.array(bin_policy["theta_bins"])
        self.phi_bins = np.array(bin_policy["phi_bins"])
        self.r_bins = np.array(bin_policy["r_bins"])

    def cartesian_to_spherical(self, x, y, z):
        theta = np.arctan2(np.sqrt(x**2 + y**2), z)  # polar angle
        phi = np.arctan2(y, x)                       # azimuth
        r = np.sqrt(x**2 + y**2 + z**2)
        return theta, phi, r

    def spherical_to_cartesian(self, theta, phi, r):
        x = r * np.sin(theta) * np.cos(phi)
        y = r * np.sin(theta) * np.sin(phi)
        z = r * np.cos(theta)
        return x, y, z

    def __call__(self, action: np.ndarray) -> List[str]:
        """Discretize continuous actions to tokens.

        action: np.ndarray, (n, 3), continuous actions in Cartesian or spherical coordinates.
        return: np.ndarray, (n,), tokens.
        """
        if self.use_spherical:
            theta, phi, r = self.cartesian_to_spherical(action[:, 0], action[:, 1], action[:, 2])
        else:
            theta, phi, r = action[:, 0], action[:, 1], action[:, 2]

        # Digitize against the interior bin boundaries, yielding indices in [0, num_bins - 1].
        disc_theta = np.digitize(theta, self.theta_bins[1:-1])
        disc_phi = np.digitize(phi, self.phi_bins[1:-1])
        disc_r = np.digitize(r, self.r_bins[1:-1])
        ids = disc_theta * self.NP + disc_phi * self.num_r_bins + disc_r
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.

        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 3), continuous actions.
        """
        action_token_id = np.clip(action_token_id, self.token_start_idx, self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        disc_theta, disc_phi, disc_r = ids // self.NP, (ids % self.NP) // self.num_r_bins, ids % self.num_r_bins

        # Map each bin index to its bin center.
        theta = 0.5 * (self.theta_bins[disc_theta] + self.theta_bins[disc_theta + 1])
        phi = 0.5 * (self.phi_bins[disc_phi] + self.phi_bins[disc_phi + 1])
        r = 0.5 * (self.r_bins[disc_r] + self.r_bins[disc_r + 1])

        x, y, z = self.spherical_to_cartesian(theta, phi, r) if self.use_spherical else (theta, phi, r)
        x, y, z = np.clip([x, y, z], -1, 1)
        return np.stack((x, y, z), axis=1)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size
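
# Index packing, for reference (illustrative bin counts, not the defaults): with
# num_theta_bins=2, num_phi_bins=3, num_r_bins=4, NP = 3 * 4 = 12, so the bin triple
# (disc_theta, disc_phi, disc_r) = (1, 2, 3) flattens to 1 * 12 + 2 * 4 + 3 = 23, and decoding
# recovers (23 // 12, (23 % 12) // 4, 23 % 4) = (1, 2, 3).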

class RotationTokenizer:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        bin_policy: Optional[Dict] = None,
        array_begin_idx=None,
    ):
        self.tokenizer = tokenizer
        self.num_roll_bins = num_bins["roll_bins"]
        self.num_pitch_bins = num_bins["pitch_bins"]
        self.num_yaw_bins = num_bins["yaw_bins"]
        # Offset into the shared ACTION token namespace (rotation tokens follow the translation tokens).
        self.array_begin_idx = array_begin_idx

        # Stride used to flatten a (roll, pitch, yaw) bin index into a single token index.
        self.NP = self.num_pitch_bins * self.num_yaw_bins

        self._vocab_size = self.num_roll_bins * self.num_pitch_bins * self.num_yaw_bins
        token_list = [ACTION_TOKEN.format(i + self.array_begin_idx) for i in range(self._vocab_size)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} ROTATION TOKENS to tokenizer; tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])
        self.set_bins(bin_policy)

    def set_bins(self, bin_policy):
        self.roll_bins = np.array(bin_policy["roll_bins"])
        self.pitch_bins = np.array(bin_policy["pitch_bins"])
        self.yaw_bins = np.array(bin_policy["yaw_bins"])

    def __call__(self, action: np.ndarray) -> List[str]:
        """Discretize continuous actions to tokens.

        action: np.ndarray, (n, 3), continuous rotation actions as (roll, pitch, yaw) Euler angles.
        return: np.ndarray, (n,), tokens.
        """
        roll, pitch, yaw = action[:, 0], action[:, 1], action[:, 2]
        disc_roll = np.clip(np.digitize(roll, self.roll_bins) - 1, 0, self.num_roll_bins - 1)
        disc_pitch = np.clip(np.digitize(pitch, self.pitch_bins) - 1, 0, self.num_pitch_bins - 1)
        disc_yaw = np.clip(np.digitize(yaw, self.yaw_bins) - 1, 0, self.num_yaw_bins - 1)

        ids = disc_roll * self.NP + disc_pitch * self.num_yaw_bins + disc_yaw
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: Union[np.int64, np.ndarray]) -> np.ndarray:
        """Decode token ids to continuous actions.

        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 3), continuous actions.
        """
        action_token_id = np.clip(action_token_id, a_min=self.token_start_idx, a_max=self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        disc_roll, disc_pitch, disc_yaw = ids // self.NP, (ids % self.NP) // self.num_yaw_bins, ids % self.num_yaw_bins

        # Map each bin index to its bin center.
        roll = 0.5 * (self.roll_bins[disc_roll] + self.roll_bins[disc_roll + 1])
        pitch = 0.5 * (self.pitch_bins[disc_pitch] + self.pitch_bins[disc_pitch + 1])
        yaw = 0.5 * (self.yaw_bins[disc_yaw] + self.yaw_bins[disc_yaw + 1])
        return np.stack((roll, pitch, yaw), axis=1)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size

class GripperTokenzier:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: int = 2,
        array_begin_idx=None,
    ) -> None:
        self.tokenizer = tokenizer
        self.num_bins = num_bins
        # Offset into the shared ACTION token namespace (gripper tokens follow translation + rotation tokens).
        self.array_begin_idx = array_begin_idx
        token_list = [ACTION_TOKEN.format(i + self.array_begin_idx) for i in range(self.num_bins)]
        self.token_array = np.array(token_list)

        num_new_tokens = self.tokenizer.add_tokens(token_list, special_tokens=True)
        print(f"Added {num_new_tokens} GRIPPER TOKENS to tokenizer; tokenizer vocab size {self.tokenizer.vocab_size} / {len(tokenizer)}")

        self.token_start_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[0])
        self.token_end_idx = self.tokenizer.convert_tokens_to_ids(self.token_array[-1])

    def __call__(self, action: np.ndarray) -> List[str]:
        """Discretize continuous actions to tokens.

        action: np.ndarray, (n,), continuous gripper actions, thresholded at 0.5 (close/open).
        return: np.ndarray, (n,), tokens.
        """
        ids = np.where(action >= 0.5, 1, 0)
        return self.token_array[ids]

    def decode_token_ids_to_actions(self, action_token_id: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.

        action_token_id: np.ndarray, (n,), token ids.
        return: np.ndarray, (n, 1), continuous actions.
        """
        action_token_id = np.clip(action_token_id, self.token_start_idx, self.token_end_idx)
        ids = action_token_id - self.token_start_idx
        actions = np.where(ids == 0, 0., 1.)
        return actions[:, None]

    @property
    def vocab_size(self) -> int:
        return self.num_bins

class SpatialActionTokenizer:
    # Value ranges covered by the translation (spherical) and rotation bins.
    range_bins = {
        "translation": {
            "theta_bins": (0.0, np.pi),
            "phi_bins": (-np.pi, np.pi),
            "r_bins": (0.0, np.sqrt(3)),
        },
        "rotation": {
            "roll_bins": (-1.0, 1.0),
            "pitch_bins": (-1.0, 1.0),
            "yaw_bins": (-1.0, 1.0),
        },
    }

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_bins: Dict,
        gs_params: Optional[Dict] = None,
        bin_policy: Optional[Dict] = None,
        use_spherical: bool = True,
        min_sigma: float = 0.0,
        min_action: float = -1.0,
        max_action: float = 1.0,
    ):
        """Use bin_policy if provided; otherwise calculate it from gs_params, or fall back to uniform bin grids.

        gs_params: Optional[Dict], per-dimension Gaussian parameters (mu, sigma).
        bin_policy: Optional[Dict], explicit bin boundaries per dimension.
        """
        self.tokenizer = tokenizer
        self.min_action, self.max_action = min_action, max_action
        self.num_bins = num_bins
        self.min_sigma = min_sigma

        self.bin_policy = bin_policy if bin_policy else self.get_bin_policy(gs_params, self.min_sigma)
        self.translation_tokenizer = TranslationTokenizer(
            self.tokenizer,
            self.num_bins["translation"],
            self.bin_policy["translation"],
            use_spherical=use_spherical,
        )

        self.rotation_tokenizer = RotationTokenizer(
            self.tokenizer,
            self.num_bins["rotation"],
            self.bin_policy["rotation"],
            array_begin_idx=self.translation_tokenizer.vocab_size,
        )

        self.gripper_tokenizer = GripperTokenzier(
            self.tokenizer,
            self.num_bins["gripper"],
            array_begin_idx=self.translation_tokenizer.vocab_size + self.rotation_tokenizer.vocab_size,
        )
        self._vocab_size = self.translation_tokenizer.vocab_size + self.rotation_tokenizer.vocab_size + self.gripper_tokenizer.vocab_size

    def __call__(self, action: np.ndarray) -> List[str]:
        """Discretize continuous actions to tokens.

        action: np.ndarray, (n, 7), continuous actions in Cartesian coordinates.
        return: np.ndarray, (n, 3), tokens (translation, rotation, gripper).
        """
        if len(action.shape) == 1:
            assert action.shape[0] == 7, f"action dim mismatch, got action shape: {action.shape}"
            action = action.reshape(1, 7)
        assert action.shape[1] == 7, f"action dim mismatch, got action shape: {action.shape}"

        action = np.clip(action, a_min=self.min_action, a_max=self.max_action)
        trans_tokens = self.translation_tokenizer(action[:, :3])
        rot_tokens = self.rotation_tokenizer(action[:, 3:6])
        grip_tokens = self.gripper_tokenizer(action[:, 6])
        return np.stack((trans_tokens, rot_tokens, grip_tokens), axis=1)

    def decode_token_ids_to_actions(self, action_token_ids: np.ndarray) -> np.ndarray:
        """Decode token ids to continuous actions.

        action_token_ids: np.ndarray, (n, 3), token ids.
        return: np.ndarray, (n, 7), continuous actions.
        """
        if len(action_token_ids.shape) == 1:
            assert action_token_ids.shape[0] == 3, f"action token id count mismatch, need 3, got {action_token_ids.shape[0]}"
            action_token_ids = action_token_ids.reshape(1, 3)
        assert action_token_ids.shape[1] == 3, f"token id count mismatch, need 3, got {action_token_ids.shape[1]}"

        trans_action = self.translation_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 0])
        rot_action = self.rotation_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 1])
        grip_action = self.gripper_tokenizer.decode_token_ids_to_actions(action_token_ids[:, 2])
        return np.concatenate((trans_action, rot_action, grip_action), axis=1)

    @property
    def vocab_size(self) -> int:
        return self._vocab_size
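
    # Usage sketch (illustrative only; the bin counts below and `base_tokenizer` are assumptions, not project defaults):
    #   num_bins = {
    #       "translation": {"theta_bins": 8, "phi_bins": 16, "r_bins": 8},
    #       "rotation": {"roll_bins": 8, "pitch_bins": 8, "yaw_bins": 8},
    #       "gripper": 2,
    #   }
    #   st = SpatialActionTokenizer(base_tokenizer, num_bins=num_bins)          # uniform bins (no gs_params)
    #   tokens = st(np.array([[0.1, -0.2, 0.3, 0.0, 0.1, -0.1, 1.0]]))          # (1, 3) token strings
    #   ids = np.array([base_tokenizer.convert_tokens_to_ids(t.tolist()) for t in tokens])
    #   actions = st.decode_token_ids_to_actions(ids)                           # (1, 7) continuous actions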

    @property
    def action_token_begin_idx(self) -> int:
        return self.translation_tokenizer.token_start_idx

    def get_bin_policy(self, gs_params=None, min_sigma=0.0):
        bin_policy = {
            "translation": {"theta_bins": None, "phi_bins": None, "r_bins": None},
            "rotation": {"roll_bins": None, "pitch_bins": None, "yaw_bins": None},
        }
        if gs_params is None:
            # No Gaussian statistics: fall back to uniform bin boundaries over each value range.
            for bin_type in self.range_bins.keys():
                for bin_key in self.range_bins[bin_type].keys():
                    bin_policy[bin_type][bin_key] = np.linspace(*self.range_bins[bin_type][bin_key], self.num_bins[bin_type][bin_key] + 1)
            print(f"use uniform bin grids ... \n{bin_policy}")
        else:
            # Place bin boundaries at equally spaced quantiles of a per-dimension Gaussian N(mu, sigma).
            for bin_type in self.range_bins.keys():
                for bin_key in self.range_bins[bin_type].keys():
                    mu = gs_params[bin_key.split("_")[0].lower()]["mu"]
                    sigma = max(gs_params[bin_key.split("_")[0].lower()]["sigma"], min_sigma)
                    bin_bound_prob = np.linspace(
                        norm.cdf(self.range_bins[bin_type][bin_key][0], loc=mu, scale=sigma),
                        norm.cdf(self.range_bins[bin_type][bin_key][1], loc=mu, scale=sigma),
                        self.num_bins[bin_type][bin_key] + 1,
                    )
                    bin_boundary = norm.ppf(bin_bound_prob, loc=mu, scale=sigma)
                    bin_policy[bin_type][bin_key] = np.clip(
                        bin_boundary,
                        self.range_bins[bin_type][bin_key][0],
                        self.range_bins[bin_type][bin_key][1],
                    ).tolist()
            print(f"calculate bin grids from Gaussians \n{bin_policy}")
        return bin_policy
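
    # Worked example of the quantile binning above (illustrative numbers): for a dimension with
    # range (-1, 1), mu=0.0, sigma=0.5 and 4 bins, the boundary probabilities are
    # linspace(norm.cdf(-1), norm.cdf(1), 5) ~= linspace(0.0228, 0.9772, 5), and norm.ppf of those
    # gives boundaries ~= [-1.0, -0.32, 0.0, 0.32, 1.0]: bins are narrower where the data is denser.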

    def get_norm_meshgrid(self, bin_policy):
        """Build normalized (0-1) meshgrids of bin centers for the translation and rotation bins."""
        grids = []
        policy = {k1: {k2: np.array(v2) for k2, v2 in v1.items()} for k1, v1 in bin_policy.items()}

        for bin_type in self.range_bins.keys():
            bounds = []
            for bin_key in self.range_bins[bin_type].keys():
                minb, maxb = self.range_bins[bin_type][bin_key][0], self.range_bins[bin_type][bin_key][1]
                bin_boundary = policy[bin_type][bin_key]
                bin_center = (bin_boundary[:-1] + bin_boundary[1:]) / 2
                # Pad with the range endpoints, then normalize centers to [0, 1].
                bin_center = np.concatenate([np.array([minb]), bin_center, np.array([maxb])])
                bin_center = (bin_center - minb) / (maxb - minb)
                bounds.append(bin_center)

            grid_x, grid_y, grid_z = np.meshgrid(*bounds)
            grids += [np.stack([grid_x, grid_y, grid_z], -1).reshape(-1, 3)]
        # (translation grid, rotation grid)
        return grids[0], grids[1]

    def spatial_embedding_adaption(self, gs_params, embeddings: torch.nn.Embedding, min_sigma=0.0, adpt_feature=False):
        """Adapt the tokenizer bins (and optionally the action-token embeddings) to a new Gaussian bin policy.

        gs_params: Dict, per-dimension Gaussian parameters used to build the new bin policy.
        embeddings: torch.nn.Embedding, weight of shape (S, E) holding the action-token embeddings.
        """
        from scipy.interpolate import griddata
        new_policy = self.get_bin_policy(gs_params, min_sigma=min_sigma)
        trans_grids0, rot_grids0 = self.get_norm_meshgrid(self.bin_policy)
        trans_grids1, rot_grids1 = self.get_norm_meshgrid(new_policy)

        print("overwrite bin policy and tokenizer bins ...")
        self.bin_policy = new_policy
        self.min_sigma = min_sigma
        self.translation_tokenizer.set_bins(new_policy["translation"])
        self.rotation_tokenizer.set_bins(new_policy["rotation"])

        if adpt_feature:
            emb_data = embeddings.weight.data
            _, E = emb_data.shape

            # Translation embeddings: reshape to a (theta, phi, r) grid, pad with replicated borders,
            # then resample at the new (normalized) bin-center grid.
            m, n, k = (self.num_bins["translation"][k] for k in ["theta_bins", "phi_bins", "r_bins"])
            N = m * n * k
            trans_emb_data = emb_data[:N,].reshape(m, n, k, -1).permute(3, 0, 1, 2)
            pad_emb = torch.nn.functional.pad(trans_emb_data, (1, 1, 1, 1, 1, 1), "replicate").permute(1, 2, 3, 0).reshape(-1, E)
            adpt_trans_emb = griddata(trans_grids0, pad_emb.float().cpu().numpy(), trans_grids1, method='linear')
            adpt_trans_emb = adpt_trans_emb.reshape(m + 2, n + 2, k + 2, E)[1:-1, 1:-1, 1:-1,]

            # Rotation embeddings: same procedure on the (roll, pitch, yaw) grid.
            m1, n1, k1 = (self.num_bins["rotation"][k] for k in ["roll_bins", "pitch_bins", "yaw_bins"])
            M = m1 * n1 * k1
            rot_emb_data = emb_data[N : N + M,].reshape(m1, n1, k1, -1).permute(3, 0, 1, 2)
            pad_emb = torch.nn.functional.pad(rot_emb_data, (1, 1, 1, 1, 1, 1), "replicate").permute(1, 2, 3, 0).reshape(-1, E)
            adpt_rot_emb = griddata(rot_grids0, pad_emb.float().cpu().numpy(), rot_grids1, method='linear')
            adpt_rot_emb = adpt_rot_emb.reshape(m1 + 2, n1 + 2, k1 + 2, E)[1:-1, 1:-1, 1:-1,]

            # Write the resampled embeddings back in place, preserving device and dtype.
            device, dtype = embeddings.weight.data.device, embeddings.weight.data.dtype
            embeddings.weight.data[:N] = torch.as_tensor(adpt_trans_emb.reshape(-1, E), dtype=dtype, device=device)
            embeddings.weight.data[N:N + M] = torch.as_tensor(adpt_rot_emb.reshape(-1, E), dtype=dtype, device=device)
            print("DONE! adapted spatial embeddings to the new Gaussian distribution.")
            print(embeddings.weight.data)
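
    # Adaption usage sketch (illustrative only; `st`, `new_gs_params`, and `action_embeddings` are hypothetical
    # caller-side names, where `action_embeddings` is an nn.Embedding whose rows are ordered as
    # translation tokens, then rotation tokens, then gripper tokens):
    #   st.spatial_embedding_adaption(new_gs_params, action_embeddings, min_sigma=0.1, adpt_feature=True)
    #   # After this call the tokenizer bins follow new_gs_params, and the translation/rotation rows of
    #   # `action_embeddings` have been resampled onto the new bin-center grid.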