# configuration_actioncodec.py — ActionCodec-Base-RVQft
# Originally uploaded by ZibinDong via huggingface_hub (revision 1bc40b2, verified).
import copy
from typing import Any, Dict, Optional

from transformers import AutoConfig, PretrainedConfig
class ActionCodecConfig(PretrainedConfig):
    """Configuration for the ActionCodec model.

    Holds the hyper-parameters of the encoder/decoder transformers and the
    vector-quantization (VQ) bottleneck, plus per-embodiment metadata
    (action dimensionality, control frequency, chunk duration).

    Args:
        embodiment_config: Mapping from embodiment name to its metadata
            (keys seen in the defaults: ``action_dim``, ``freq``,
            ``duration``, ``description``). ``None`` selects a built-in
            default covering three embodiments.
        n_tokens: Number of latent tokens.
        n_quantizers: Number of quantizers.
        z_dim: Latent (codebook) dimensionality.
        vq_type: Quantizer family identifier (default ``"vq"``).
        vq_codebook_size: Number of entries per codebook.
        vq_commitment_weight: Commitment-loss weight.
        vq_decay: EMA decay for codebook updates.
        vq_kmeans_init: Whether to initialize codebooks with k-means.
        vq_threshold_ema_dead_code: EMA count below which a code is treated
            as dead.
        vq_quantizer_dropout: Quantizer-dropout probability.
        encoder_dim / encoder_n_layers / encoder_n_heads: Encoder width,
            depth, and attention-head count.
        encoder_add_self_attn / encoder_add_causal_mask: Encoder attention
            options.
        encoder_pos_encoding_type: Encoder positional-encoding type
            (default ``"fourier"``).
        decoder_dim / decoder_n_layers / decoder_n_heads: Decoder width,
            depth, and attention-head count.
        decoder_add_self_attn / decoder_add_causal_mask: Decoder attention
            options.
        decoder_pos_encoding_type: Decoder positional-encoding type
            (default ``"fourier"``).
        decoder_cls_size: Decoder CLS size.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "action_codec"

    def __init__(
        self,
        embodiment_config: Optional[Dict[str, Any]] = None,
        n_tokens: int = 16,
        n_quantizers: int = 1,
        z_dim: int = 512,
        vq_type: str = "vq",
        vq_codebook_size: int = 2048,
        vq_commitment_weight: float = 0.25,
        vq_decay: float = 0.99,
        vq_kmeans_init: bool = True,
        vq_threshold_ema_dead_code: int = 2,
        vq_quantizer_dropout: float = 0.25,
        encoder_dim: int = 256,
        encoder_n_layers: int = 6,
        encoder_n_heads: int = 8,
        encoder_add_self_attn: bool = False,
        encoder_add_causal_mask: bool = False,
        encoder_pos_encoding_type: str = "fourier",
        decoder_dim: int = 256,
        decoder_n_layers: int = 6,
        decoder_n_heads: int = 8,
        decoder_add_self_attn: bool = False,
        decoder_add_causal_mask: bool = False,
        decoder_pos_encoding_type: str = "fourier",
        decoder_cls_size: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if embodiment_config is None:
            # Freshly built literal — no defensive copy needed.
            self.embodiment_config = {
                "franka_libero_20hz": {
                    "action_dim": 7,
                    "freq": 20,
                    "duration": 1,
                    "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "widowx_bridge_5hz": {
                    "action_dim": 7,
                    "freq": 5,
                    "duration": 1,
                    "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "franka_droid_15hz": {
                    "action_dim": 7,
                    "freq": 15,
                    "duration": 1,
                    "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
            }
        else:
            # Deep-copy so later mutation of the caller's dict cannot
            # silently alter this config (and vice versa).
            self.embodiment_config = copy.deepcopy(embodiment_config)
        self.n_tokens = n_tokens
        self.n_quantizers = n_quantizers
        self.z_dim = z_dim
        self.encoder_dim = encoder_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_n_heads = encoder_n_heads
        self.encoder_add_self_attn = encoder_add_self_attn
        self.encoder_add_causal_mask = encoder_add_causal_mask
        self.encoder_pos_encoding_type = encoder_pos_encoding_type
        self.decoder_dim = decoder_dim
        self.decoder_n_layers = decoder_n_layers
        self.decoder_n_heads = decoder_n_heads
        self.decoder_add_self_attn = decoder_add_self_attn
        self.decoder_add_causal_mask = decoder_add_causal_mask
        self.decoder_pos_encoding_type = decoder_pos_encoding_type
        self.decoder_cls_size = decoder_cls_size
        self.vq_type = vq_type
        self.vq_codebook_size = vq_codebook_size
        self.vq_commitment_weight = vq_commitment_weight
        self.vq_decay = vq_decay
        self.vq_kmeans_init = vq_kmeans_init
        self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
        self.vq_quantizer_dropout = vq_quantizer_dropout
class ActionCodecConfigOld(PretrainedConfig):
    """Legacy configuration for the ActionCodec model.

    Earlier config layout that stores the encoder, decoder, and quantizer as
    dotted class paths plus free-form keyword-argument dicts, instead of the
    flattened fields used by :class:`ActionCodecConfig`.

    NOTE(review): ``model_type`` is ``"action_codec"``, the same string as
    ``ActionCodecConfig`` — only the latter is registered with ``AutoConfig``;
    confirm the duplication is intentional.

    Args:
        horizon: Number of action steps in one chunk.
        action_dim: Per-step action dimensionality.
        action_encoding: Action-encoding scheme identifier.
        horizon_patch_size: Patch size along the horizon axis.
        encoder_class: Dotted import path of the encoder class.
        decoder_class: Dotted import path of the decoder class.
        vq_class: Dotted import path of the vector-quantizer class.
        encoder_kwargs: Keyword arguments for the encoder; ``None`` selects
            built-in defaults (which embed ``horizon`` as the input length).
        decoder_kwargs: Keyword arguments for the decoder; ``None`` selects
            built-in defaults (which embed ``horizon`` as the output length).
        vq_kwargs: Keyword arguments for the quantizer; ``None`` selects
            built-in defaults.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
        decoder_kwargs: Optional[Dict[str, Any]] = None,
        vq_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Shallow-copy caller-supplied dicts so the config owns its state;
        # fall back to built-in defaults when a dict is not given.
        self.encoder_kwargs = (
            dict(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            dict(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            dict(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
class BPEActionCodecConfig(PretrainedConfig):
    """Configuration for the BPE variant of the ActionCodec model.

    Structurally identical to :class:`ActionCodecConfigOld` (dotted class
    paths plus free-form kwargs dicts) but registered under its own
    ``model_type`` ``"bpe_action_codec"``.

    Args:
        horizon: Number of action steps in one chunk.
        action_dim: Per-step action dimensionality.
        action_encoding: Action-encoding scheme identifier.
        horizon_patch_size: Patch size along the horizon axis.
        encoder_class: Dotted import path of the encoder class.
        decoder_class: Dotted import path of the decoder class.
        vq_class: Dotted import path of the vector-quantizer class.
        encoder_kwargs: Keyword arguments for the encoder; ``None`` selects
            built-in defaults (which embed ``horizon`` as the input length).
        decoder_kwargs: Keyword arguments for the decoder; ``None`` selects
            built-in defaults (which embed ``horizon`` as the output length).
        vq_kwargs: Keyword arguments for the quantizer; ``None`` selects
            built-in defaults.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
    """

    model_type = "bpe_action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
        decoder_kwargs: Optional[Dict[str, Any]] = None,
        vq_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Shallow-copy caller-supplied dicts so the config owns its state;
        # fall back to built-in defaults when a dict is not given.
        self.encoder_kwargs = (
            dict(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            dict(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            dict(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
# Register both model_type strings with transformers' AutoConfig so that
# AutoConfig.from_pretrained(...) resolves them to these classes.
AutoConfig.register("action_codec", ActionCodecConfig)
AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
# Record this module in the config's auto_map on save_pretrained(), enabling
# trust_remote_code loading from the Hub.
# NOTE(review): only ActionCodecConfig is registered for the auto class;
# BPEActionCodecConfig is not — confirm that is intentional.
ActionCodecConfig.register_for_auto_class()
__all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]