| import copy | |
| from typing import Any, Dict | |
| from transformers import AutoConfig, PretrainedConfig | |
class ActionCodecConfig(PretrainedConfig):
    """Configuration for the action codec.

    Bundles three groups of hyper-parameters: the vector-quantizer settings
    (``vq_*``), the encoder/decoder transformer shapes (``encoder_*`` /
    ``decoder_*``), and a per-embodiment table describing each robot's
    action space (``embodiment_config``).
    """

    model_type = "action_codec"

    def __init__(
        self,
        embodiment_config: Dict[str, Any] = None,
        n_tokens: int = 16,
        n_quantizers: int = 1,
        z_dim: int = 512,
        vq_type: str = "vq",
        vq_codebook_size: int = 2048,
        vq_commitment_weight: float = 0.25,
        vq_decay: float = 0.99,
        vq_kmeans_init: bool = True,
        vq_threshold_ema_dead_code: int = 2,
        vq_quantizer_dropout: float = 0.25,
        encoder_dim: int = 256,
        encoder_n_layers: int = 6,
        encoder_n_heads: int = 8,
        encoder_add_self_attn: bool = False,
        encoder_add_causal_mask: bool = False,
        encoder_pos_encoding_type: str = "fourier",
        decoder_dim: int = 256,
        decoder_n_layers: int = 6,
        decoder_n_heads: int = 8,
        decoder_add_self_attn: bool = False,
        decoder_add_causal_mask: bool = False,
        decoder_pos_encoding_type: str = "fourier",
        decoder_cls_size: int = 1,
        **kwargs,
    ):
        """Build the config; extra ``kwargs`` are forwarded to ``PretrainedConfig``."""
        super().__init__(**kwargs)

        # No table supplied: fall back to the built-in set of embodiments.
        if embodiment_config is None:
            embodiment_config = {
                "franka_libero_20hz": {
                    "action_dim": 7,
                    "freq": 20,
                    "duration": 1,
                    "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "widowx_bridge_5hz": {
                    "action_dim": 7,
                    "freq": 5,
                    "duration": 1,
                    "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "franka_droid_15hz": {
                    "action_dim": 7,
                    "freq": 15,
                    "duration": 1,
                    "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
            }
        # Deep-copy so later mutation of the caller's dict cannot leak into
        # (or out of) this config instance.
        self.embodiment_config = copy.deepcopy(embodiment_config)

        # Tokenization / latent space.
        self.n_tokens = n_tokens
        self.n_quantizers = n_quantizers
        self.z_dim = z_dim

        # Encoder transformer shape.
        self.encoder_dim = encoder_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_n_heads = encoder_n_heads
        self.encoder_add_self_attn = encoder_add_self_attn
        self.encoder_add_causal_mask = encoder_add_causal_mask
        self.encoder_pos_encoding_type = encoder_pos_encoding_type

        # Decoder transformer shape.
        self.decoder_dim = decoder_dim
        self.decoder_n_layers = decoder_n_layers
        self.decoder_n_heads = decoder_n_heads
        self.decoder_add_self_attn = decoder_add_self_attn
        self.decoder_add_causal_mask = decoder_add_causal_mask
        self.decoder_pos_encoding_type = decoder_pos_encoding_type
        self.decoder_cls_size = decoder_cls_size

        # Vector-quantizer settings.
        self.vq_type = vq_type
        self.vq_codebook_size = vq_codebook_size
        self.vq_commitment_weight = vq_commitment_weight
        self.vq_decay = vq_decay
        self.vq_kmeans_init = vq_kmeans_init
        self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
        self.vq_quantizer_dropout = vq_quantizer_dropout
class ActionCodecConfigOld(PretrainedConfig):
    """Legacy action-codec configuration (class-path based wiring).

    Components are declared as dotted import paths (``encoder_class``,
    ``decoder_class``, ``vq_class``) plus per-component kwargs dicts,
    instead of the flat hyper-parameters used by ``ActionCodecConfig``.

    NOTE(review): ``model_type`` is "action_codec", the same string as the
    newer ``ActionCodecConfig`` — presumably kept for old-checkpoint
    compatibility; this legacy class is not registered with ``AutoConfig``.
    Confirm before registering it anywhere.
    """

    model_type = "action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Dict[str, Any] = None,
        decoder_kwargs: Dict[str, Any] = None,
        vq_kwargs: Dict[str, Any] = None,
        **kwargs,
    ):
        """Build the config; extra ``kwargs`` are forwarded to ``PretrainedConfig``.

        The ``*_kwargs`` dicts default to the values below when not given;
        note the defaults derive ``in_len``/``out_len`` from ``horizon``.
        """
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Deep-copy caller-supplied dicts: the previous shallow ``dict(...)``
        # copy left nested values aliased with the caller, so mutating the
        # caller's dict could mutate this config after construction.
        # ActionCodecConfig already deep-copies for the same reason.
        self.encoder_kwargs = (
            copy.deepcopy(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            copy.deepcopy(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            copy.deepcopy(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
class BPEActionCodecConfig(PretrainedConfig):
    """Configuration for the BPE variant of the action codec.

    Same class-path based wiring as the legacy config: components are
    declared as dotted import paths (``encoder_class``, ``decoder_class``,
    ``vq_class``) plus per-component kwargs dicts.
    """

    model_type = "bpe_action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Dict[str, Any] = None,
        decoder_kwargs: Dict[str, Any] = None,
        vq_kwargs: Dict[str, Any] = None,
        **kwargs,
    ):
        """Build the config; extra ``kwargs`` are forwarded to ``PretrainedConfig``.

        The ``*_kwargs`` dicts default to the values below when not given;
        note the defaults derive ``in_len``/``out_len`` from ``horizon``.
        """
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Deep-copy caller-supplied dicts: the previous shallow ``dict(...)``
        # copy left nested values aliased with the caller, so mutating the
        # caller's dict could mutate this config after construction.
        # ActionCodecConfig already deep-copies for the same reason.
        self.encoder_kwargs = (
            copy.deepcopy(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            copy.deepcopy(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            copy.deepcopy(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
# Register the config classes with AutoConfig under their model_type keys so
# they can be resolved from checkpoints that record these strings.
AutoConfig.register("action_codec", ActionCodecConfig)
AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)
# Export the main config for auto-class (trust_remote_code) loading.
# NOTE(review): BPEActionCodecConfig is registered with AutoConfig but not
# exported via register_for_auto_class() — confirm whether that asymmetry is
# intentional.
ActionCodecConfig.register_for_auto_class()
# NOTE(review): ActionCodecConfigOld is omitted here — presumably because it
# is legacy-only; confirm no external caller imports it via `*`.
__all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]