Upload folder using huggingface_hub
Browse files- configuration_actioncodec.py +228 -0
- modeling_actioncodec.py +541 -0
- modular_actioncodec.py +779 -0
- rvq.py +522 -0
configuration_actioncodec.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
from transformers import AutoConfig, PretrainedConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ActionCodecConfig(PretrainedConfig):
    """Configuration for the :class:`ActionCodec` model.

    Bundles three groups of hyper-parameters:

    * ``embodiment_config`` — a mapping from embodiment name to a dict
      describing that robot's action space (``action_dim``, ``freq``,
      ``duration``, ``description``). When omitted, three reference
      embodiments (LIBERO Franka, Bridge WidowX, DROID Franka) are used.
    * encoder/decoder — Perceiver transformer sizes and options.
    * vq_* — vector-quantizer settings, shared by both ``"vq"`` and
      ``"rvq"`` quantizer types.

    Args:
        embodiment_config: Per-embodiment action-space descriptions, or
            ``None`` to use the built-in defaults. The dict is deep-copied
            so callers cannot mutate the config afterwards.
        n_tokens: Total number of discrete tokens per action sequence.
        n_quantizers: Number of (residual) quantizer levels.
        z_dim: Latent dimension fed to the quantizer.
        vq_type: ``"vq"`` (single codebook) or ``"rvq"`` (residual VQ).
        vq_codebook_size: Entries per codebook.
        vq_commitment_weight: Commitment-loss weight.
        vq_decay: EMA decay for codebook updates.
        vq_kmeans_init: Whether to initialize codebooks with k-means.
        vq_threshold_ema_dead_code: EMA count below which a code is
            considered dead and re-sampled.
        vq_quantizer_dropout: Quantizer dropout rate (RVQ only).
        encoder_dim / encoder_n_layers / encoder_n_heads: Encoder size.
        encoder_add_self_attn: Add self-attention blocks in the encoder.
        encoder_add_causal_mask: Apply a causal mask in the encoder.
        encoder_pos_encoding_type: Positional encoding type (e.g. "fourier").
        decoder_*: Same options for the decoder.
        decoder_cls_size: Number of CLS slots in the decoder.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "action_codec"

    def __init__(
        self,
        embodiment_config: Dict[str, Any] | None = None,
        n_tokens: int = 16,
        n_quantizers: int = 1,
        z_dim: int = 512,
        vq_type: str = "vq",
        vq_codebook_size: int = 2048,
        vq_commitment_weight: float = 0.25,
        vq_decay: float = 0.99,
        vq_kmeans_init: bool = True,
        vq_threshold_ema_dead_code: int = 2,
        vq_quantizer_dropout: float = 0.25,
        encoder_dim: int = 256,
        encoder_n_layers: int = 6,
        encoder_n_heads: int = 8,
        encoder_add_self_attn: bool = False,
        encoder_add_causal_mask: bool = False,
        encoder_pos_encoding_type: str = "fourier",
        decoder_dim: int = 256,
        decoder_n_layers: int = 6,
        decoder_n_heads: int = 8,
        decoder_add_self_attn: bool = False,
        decoder_add_causal_mask: bool = False,
        decoder_pos_encoding_type: str = "fourier",
        decoder_cls_size: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if embodiment_config is None:
            default_config = {
                "franka_libero_20hz": {
                    "action_dim": 7,
                    "freq": 20,
                    "duration": 1,
                    "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "widowx_bridge_5hz": {
                    "action_dim": 7,
                    "freq": 5,
                    "duration": 1,
                    "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "franka_droid_15hz": {
                    "action_dim": 7,
                    "freq": 15,
                    "duration": 1,
                    "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
            }
            self.embodiment_config = copy.deepcopy(default_config)
        else:
            # Deep copy so later mutation of the caller's dict cannot
            # silently alter this config (and vice versa).
            self.embodiment_config = copy.deepcopy(embodiment_config)

        self.n_tokens = n_tokens
        self.n_quantizers = n_quantizers
        self.z_dim = z_dim

        self.encoder_dim = encoder_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_n_heads = encoder_n_heads
        self.encoder_add_self_attn = encoder_add_self_attn
        self.encoder_add_causal_mask = encoder_add_causal_mask
        self.encoder_pos_encoding_type = encoder_pos_encoding_type

        self.decoder_dim = decoder_dim
        self.decoder_n_layers = decoder_n_layers
        self.decoder_n_heads = decoder_n_heads
        self.decoder_add_self_attn = decoder_add_self_attn
        self.decoder_add_causal_mask = decoder_add_causal_mask
        self.decoder_pos_encoding_type = decoder_pos_encoding_type
        self.decoder_cls_size = decoder_cls_size

        self.vq_type = vq_type
        self.vq_codebook_size = vq_codebook_size
        self.vq_commitment_weight = vq_commitment_weight
        self.vq_decay = vq_decay
        self.vq_kmeans_init = vq_kmeans_init
        self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
        self.vq_quantizer_dropout = vq_quantizer_dropout
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class ActionCodecConfigOld(PretrainedConfig):
    """Legacy ActionCodec configuration.

    Kept for backward compatibility with older checkpoints: stores the
    dotted class paths of the encoder/decoder/VQ modules together with
    their constructor kwargs instead of flat hyper-parameters.
    """

    model_type = "action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Dict[str, Any] = None,
        decoder_kwargs: Dict[str, Any] = None,
        vq_kwargs: Dict[str, Any] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class

        # Fall back to the historical default hyper-parameters whenever a
        # kwargs dict is not supplied; otherwise take a shallow copy so the
        # caller's dict is not shared with this config.
        if encoder_kwargs is None:
            self.encoder_kwargs = {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        else:
            self.encoder_kwargs = dict(encoder_kwargs)

        if decoder_kwargs is None:
            self.decoder_kwargs = {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        else:
            self.decoder_kwargs = dict(decoder_kwargs)

        if vq_kwargs is None:
            self.vq_kwargs = {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        else:
            self.vq_kwargs = dict(vq_kwargs)
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class BPEActionCodecConfig(PretrainedConfig):
    """Configuration for the BPE variant of the action codec.

    Mirrors :class:`ActionCodecConfigOld`: component classes are referenced
    by dotted path and configured through per-component kwargs dicts.
    """

    model_type = "bpe_action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Dict[str, Any] = None,
        decoder_kwargs: Dict[str, Any] = None,
        vq_kwargs: Dict[str, Any] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class

        # Historical defaults, used only when the corresponding kwargs dict
        # is not provided by the caller.
        default_encoder_kwargs = {
            "dim": 384,
            "in_len": horizon,
            "out_len": 16,
            "num_layers": 12,
            "num_heads": 4,
            "output_round": -1.0,
        }
        default_decoder_kwargs = {
            "dim": 384,
            "in_len": 16,
            "out_len": horizon,
            "num_layers": 12,
            "num_heads": 4,
        }
        default_vq_kwargs = {
            "dim": 512,
            "codebook_size": 2048,
            "kmeans_init": True,
            "kmeans_iters": 10,
            "decay": 0.99,
            "commitment_weight": 0.25,
            "rotation_trick": False,
            "threshold_ema_dead_code": 2,
            "use_cosine_sim": False,
            "codebook_diversity_loss_weight": 0.0,
        }

        # Copy caller-supplied dicts so they are not shared with the config.
        self.encoder_kwargs = default_encoder_kwargs if encoder_kwargs is None else dict(encoder_kwargs)
        self.decoder_kwargs = default_decoder_kwargs if decoder_kwargs is None else dict(decoder_kwargs)
        self.vq_kwargs = default_vq_kwargs if vq_kwargs is None else dict(vq_kwargs)
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# Register the custom model types with the HF Auto machinery so that
# AutoConfig.from_pretrained(...) can resolve them by their model_type.
AutoConfig.register("action_codec", ActionCodecConfig)
AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)

__all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
|
modeling_actioncodec.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import einops
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoModel, PreTrainedModel
|
| 7 |
+
from vector_quantize_pytorch import VectorQuantize
|
| 8 |
+
|
| 9 |
+
from .configuration_actioncodec import ActionCodecConfig
|
| 10 |
+
from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder
|
| 11 |
+
from .rvq import ResidualVectorQuantize
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip trailing zeros from every row of a 2-D array.

    Args:
        arr: Array of shape (b, n), e.g. zero-padded token-id sequences.

    Returns:
        A list of b Python lists; row i is truncated just after its last
        non-zero element. A row containing only zeros becomes ``[]``; an
        array with zero rows yields ``[]``.
    """
    if arr.shape[0] == 0:
        return []

    n = arr.shape[1]

    is_nonzero = arr != 0
    # argmax over the reversed mask finds the first True from the right,
    # i.e. the position of the last non-zero element of each row.
    flipped_mask = np.flip(is_nonzero, axis=1)
    last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
    # For an all-zero row argmax returns 0, which would wrongly keep the
    # whole row; multiplying by the any-nonzero mask forces its length to 0.
    any_nonzero_in_row = is_nonzero.any(axis=1)
    new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row

    return [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ActionCodec(PreTrainedModel):
|
| 31 |
+
config_class = ActionCodecConfig
|
| 32 |
+
|
| 33 |
+
def __init__(self, config: ActionCodecConfig):
|
| 34 |
+
super().__init__(config)
|
| 35 |
+
self.default_embodiment_id = 0
|
| 36 |
+
|
| 37 |
+
self.encoder = PerceiverEncoder(config)
|
| 38 |
+
self.decoder = PerceiverDecoder(config)
|
| 39 |
+
|
| 40 |
+
if config.vq_type == "vq":
|
| 41 |
+
assert config.n_quantizers == 1, "Only one quantizer is supported for VQ"
|
| 42 |
+
self.vq = VectorQuantize(
|
| 43 |
+
dim=config.z_dim,
|
| 44 |
+
codebook_size=config.vq_codebook_size,
|
| 45 |
+
commitment_weight=config.vq_commitment_weight,
|
| 46 |
+
decay=config.vq_decay,
|
| 47 |
+
kmeans_init=config.vq_kmeans_init,
|
| 48 |
+
threshold_ema_dead_code=config.vq_threshold_ema_dead_code,
|
| 49 |
+
rotation_trick=False,
|
| 50 |
+
straight_through=True,
|
| 51 |
+
)
|
| 52 |
+
elif config.vq_type == "rvq":
|
| 53 |
+
assert config.n_quantizers > 1, "At least two quantizers are supported for RVQ"
|
| 54 |
+
self.vq = ResidualVectorQuantize(
|
| 55 |
+
dim=config.z_dim,
|
| 56 |
+
n_codebooks=config.n_quantizers,
|
| 57 |
+
codebook_size=config.vq_codebook_size,
|
| 58 |
+
codebook_dim=config.z_dim,
|
| 59 |
+
quantizer_dropout=config.vq_quantizer_dropout,
|
| 60 |
+
commitment=config.vq_commitment_weight,
|
| 61 |
+
)
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError(f"VQ type {config.vq_type} not implemented")
|
| 64 |
+
|
| 65 |
+
self.vocab_size = config.vq_codebook_size
|
| 66 |
+
self.num_quantizers = config.n_quantizers
|
| 67 |
+
self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers
|
| 68 |
+
|
| 69 |
+
def expand_embodiment(self, embodiment_config: dict):
|
| 70 |
+
"""
|
| 71 |
+
Delegates expansion to the underlying Encoder and Decoder.
|
| 72 |
+
This allows the Codec to adapt to new robots dynamically.
|
| 73 |
+
"""
|
| 74 |
+
self.encoder.expand_embodiment(embodiment_config)
|
| 75 |
+
self.decoder.expand_embodiment(embodiment_config)
|
| 76 |
+
self.config.embodiment_config.update(embodiment_config)
|
| 77 |
+
return self
|
| 78 |
+
|
| 79 |
+
def _encode(
|
| 80 |
+
self,
|
| 81 |
+
x: torch.Tensor,
|
| 82 |
+
embodiment_ids: torch.Tensor | int | None = None,
|
| 83 |
+
padding_mask: torch.Tensor | None = None,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Encode action sequences into latent representations.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
|
| 89 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 90 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
|
| 91 |
+
embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
|
| 92 |
+
If int, the same embodiment ID is repeated for all sequences in the batch.
|
| 93 |
+
It specifies the embodiment to encode.
|
| 94 |
+
padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
|
| 95 |
+
It is used to mask the padding tokens on `seq_len` dimension.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
|
| 99 |
+
"""
|
| 100 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 101 |
+
z_e = self.encoder(x, embodiment_ids, padding_mask)
|
| 102 |
+
return z_e
|
| 103 |
+
|
| 104 |
+
def _quantize(self, z_e: torch.Tensor, return_perplexity: bool = True) -> List[torch.Tensor]:
|
| 105 |
+
if isinstance(self.vq, ResidualVectorQuantize):
|
| 106 |
+
z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e)
|
| 107 |
+
commit_loss = commitment_loss.mean() + codebook_loss.mean()
|
| 108 |
+
elif isinstance(self.vq, VectorQuantize):
|
| 109 |
+
z_q, indices, commit_loss = self.vq(z_e)
|
| 110 |
+
else:
|
| 111 |
+
raise NotImplementedError(f"VQ type {type(self.vq)} not implemented")
|
| 112 |
+
|
| 113 |
+
if return_perplexity:
|
| 114 |
+
if len(indices.size()) < 3:
|
| 115 |
+
indices = indices.unsqueeze(-1)
|
| 116 |
+
perplexity = []
|
| 117 |
+
for k in range(indices.size(-1)):
|
| 118 |
+
this_indices = indices[:, :, k]
|
| 119 |
+
indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size)
|
| 120 |
+
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
|
| 121 |
+
torch.distributed.all_reduce(indices_count)
|
| 122 |
+
this_avg_probs = indices_count.float() / indices_count.sum()
|
| 123 |
+
perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item()))
|
| 124 |
+
else:
|
| 125 |
+
perplexity = 0
|
| 126 |
+
|
| 127 |
+
return z_q, indices, perplexity, commit_loss
|
| 128 |
+
|
| 129 |
+
def _dequantize(self, indices: torch.Tensor) -> torch.Tensor:
|
| 130 |
+
if self.num_quantizers == 1:
|
| 131 |
+
if len(indices.size()) == 3:
|
| 132 |
+
indices = indices.squeeze(-1)
|
| 133 |
+
if isinstance(self.vq, ResidualVectorQuantize):
|
| 134 |
+
z_q = self.vq.from_codes(indices)[0]
|
| 135 |
+
else:
|
| 136 |
+
z_q = self.vq.get_output_from_indices(indices)
|
| 137 |
+
return z_q
|
| 138 |
+
|
| 139 |
+
def _decode(
|
| 140 |
+
self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None
|
| 141 |
+
) -> torch.Tensor:
|
| 142 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 143 |
+
x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations)
|
| 144 |
+
return x_recon, padding_mask
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
|
| 147 |
+
def encode(
|
| 148 |
+
self,
|
| 149 |
+
x: np.ndarray,
|
| 150 |
+
embodiment_ids: List[int] | int | None = None,
|
| 151 |
+
padding_mask: List[bool] | None = None,
|
| 152 |
+
) -> List[List[int]]:
|
| 153 |
+
"""Encode action sequences into latent representations.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
x (np.ndarray): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
|
| 157 |
+
Assumes that the action dimension is zero-padded to the max action dimension.
|
| 158 |
+
`seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
|
| 159 |
+
embodiment_ids (List[int] | int): Embodiment IDs. Shape: (b,).
|
| 160 |
+
If int, the same embodiment ID is repeated for all sequences in the batch.
|
| 161 |
+
It specifies the embodiment to encode.
|
| 162 |
+
padding_mask (List[bool] | None): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
|
| 163 |
+
It is used to mask the padding tokens on `seq_len` dimension.
|
| 164 |
+
|
| 165 |
+
Returns:
|
| 166 |
+
List[List[int]]: List of token sequences. Shape: (b, n_tokens).
|
| 167 |
+
"""
|
| 168 |
+
self.eval()
|
| 169 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 170 |
+
|
| 171 |
+
with torch.no_grad():
|
| 172 |
+
x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device)
|
| 173 |
+
if not isinstance(embodiment_ids, int):
|
| 174 |
+
embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
|
| 175 |
+
if padding_mask is not None:
|
| 176 |
+
padding_mask = torch.tensor(padding_mask, dtype=torch.bool, device=self.device)
|
| 177 |
+
|
| 178 |
+
z_e = self._encode(x_tensor, embodiment_ids, padding_mask)
|
| 179 |
+
_, indices, _, _ = self._quantize(z_e, return_perplexity=False)
|
| 180 |
+
if len(indices.size()) > 2:
|
| 181 |
+
codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu()
|
| 182 |
+
else:
|
| 183 |
+
codes_list = indices.cpu()
|
| 184 |
+
codes_list = codes_list.tolist()
|
| 185 |
+
return codes_list
|
| 186 |
+
|
| 187 |
+
@torch.no_grad()
|
| 188 |
+
def decode(
|
| 189 |
+
self, tokens: List[List[int]], embodiment_ids: List[int] | int | None = None, durations: List[float] | None = None
|
| 190 |
+
) -> np.ndarray:
|
| 191 |
+
self.eval()
|
| 192 |
+
embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id
|
| 193 |
+
tokens = torch.tensor(tokens, dtype=torch.long, device=self.device)
|
| 194 |
+
if not isinstance(embodiment_ids, int):
|
| 195 |
+
embodiment_ids = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device)
|
| 196 |
+
if durations is not None:
|
| 197 |
+
durations = torch.tensor(durations, dtype=torch.float32, device=self.device)
|
| 198 |
+
|
| 199 |
+
b, n = tokens.shape
|
| 200 |
+
assert n % self.n_tokens_per_quantizer == 0, (
|
| 201 |
+
f"Expected {self.n_tokens_per_quantizer} tokens per quantizer, got {n} in total."
|
| 202 |
+
)
|
| 203 |
+
indices = einops.rearrange(tokens, "b (n m) -> b m n", m=self.n_tokens_per_quantizer)
|
| 204 |
+
z_q = self._dequantize(indices)
|
| 205 |
+
x_recon, padding_mask = self._decode(z_q, embodiment_ids, durations)
|
| 206 |
+
return x_recon.cpu().numpy(), padding_mask.cpu().numpy()
|
| 207 |
+
|
| 208 |
+
# def sparse_encode(
|
| 209 |
+
# self,
|
| 210 |
+
# x: np.ndarray,
|
| 211 |
+
# search_num: int = 10,
|
| 212 |
+
# threshold: float = 0.1,
|
| 213 |
+
# action_encoding: str | None = None,
|
| 214 |
+
# remove_padding: bool = True,
|
| 215 |
+
# ) -> List[List[int]]:
|
| 216 |
+
# """
|
| 217 |
+
# Sparse encoding with adaptive token selection based on reconstruction error threshold.
|
| 218 |
+
# Uses quaternary search to find optimal token length.
|
| 219 |
+
|
| 220 |
+
# Args:
|
| 221 |
+
# x: Input action arrays of shape (b, n, d)
|
| 222 |
+
# search_num: Maximum number of search iterations
|
| 223 |
+
# threshold: Reconstruction error threshold
|
| 224 |
+
# action_encoding: Action encoding type
|
| 225 |
+
# remove_padding: Whether to remove trailing zeros
|
| 226 |
+
|
| 227 |
+
# Returns:
|
| 228 |
+
# List of sparse token sequences
|
| 229 |
+
# """
|
| 230 |
+
# self.eval()
|
| 231 |
+
# with torch.no_grad():
|
| 232 |
+
# x_tensor = self._numpy_to_tensor(x)
|
| 233 |
+
|
| 234 |
+
# # Get initial encoding
|
| 235 |
+
# z_e = self._encode(x_tensor, action_encoding)
|
| 236 |
+
# _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
|
| 237 |
+
|
| 238 |
+
# # Convert indices to proper format
|
| 239 |
+
# if len(indices.size()) > 2:
|
| 240 |
+
# indices_flat = einops.rearrange(indices, "b n s -> b (s n)")
|
| 241 |
+
# else:
|
| 242 |
+
# indices_flat = indices
|
| 243 |
+
|
| 244 |
+
# # Use quaternary search to find optimal token lengths
|
| 245 |
+
# optimal_lengths = self._quaternary_search(x_tensor, indices_flat, threshold, search_num, action_encoding)
|
| 246 |
+
|
| 247 |
+
# # Create final sparse tokens based on optimal lengths
|
| 248 |
+
# final_tokens = self._create_sparse_tokens_from_lengths(indices_flat, optimal_lengths)
|
| 249 |
+
|
| 250 |
+
# # Convert to list format
|
| 251 |
+
# if remove_padding:
|
| 252 |
+
# final_tokens = trim_trailing_zeros(final_tokens.cpu().numpy())
|
| 253 |
+
# else:
|
| 254 |
+
# final_tokens = final_tokens.cpu().tolist()
|
| 255 |
+
|
| 256 |
+
# return final_tokens
|
| 257 |
+
|
| 258 |
+
# def _quaternary_search(
|
| 259 |
+
# self,
|
| 260 |
+
# x_tensor: torch.Tensor,
|
| 261 |
+
# indices_flat: torch.Tensor,
|
| 262 |
+
# threshold: float,
|
| 263 |
+
# search_num: int,
|
| 264 |
+
# action_encoding: str | None = None,
|
| 265 |
+
# ) -> torch.Tensor:
|
| 266 |
+
# """
|
| 267 |
+
# Quaternary search to find optimal token lengths for each batch item.
|
| 268 |
+
# Returns tensor of shape (batch_size,) containing optimal lengths.
|
| 269 |
+
# """
|
| 270 |
+
# batch_size, seq_len = indices_flat.shape
|
| 271 |
+
|
| 272 |
+
# # Initialize search bounds
|
| 273 |
+
# device = indices_flat.device
|
| 274 |
+
# left = torch.ones(batch_size, dtype=torch.long, device=device)
|
| 275 |
+
# right = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
|
| 276 |
+
|
| 277 |
+
# # Perform quaternary search
|
| 278 |
+
# for _ in range(search_num):
|
| 279 |
+
# # Calculate three division points
|
| 280 |
+
# range_size = right - left
|
| 281 |
+
# q1 = left + range_size // 4
|
| 282 |
+
# q2 = left + range_size // 2
|
| 283 |
+
# q3 = left + 3 * range_size // 4
|
| 284 |
+
|
| 285 |
+
# # Ensure q1, q2, q3 are within bounds and distinct
|
| 286 |
+
# q1 = torch.clamp(q1, left, right)
|
| 287 |
+
# q2 = torch.clamp(q2, q1 + 1, right)
|
| 288 |
+
# q3 = torch.clamp(q3, q2 + 1, right)
|
| 289 |
+
|
| 290 |
+
# # Create test lengths: [left, q1, q2, q3, right]
|
| 291 |
+
# test_lengths = torch.stack([left, q1, q2, q3, right], dim=1) # (batch_size, 5)
|
| 292 |
+
|
| 293 |
+
# # Calculate errors for all test lengths
|
| 294 |
+
# errors = self._calculate_errors_for_lengths(x_tensor, indices_flat, test_lengths, action_encoding)
|
| 295 |
+
|
| 296 |
+
# # Update search bounds based on results (vectorized)
|
| 297 |
+
# # Find which lengths meet threshold for each batch item
|
| 298 |
+
# meets_threshold = errors <= threshold
|
| 299 |
+
|
| 300 |
+
# # For each batch item, find the smallest length that meets threshold
|
| 301 |
+
# valid_indices = torch.argmax(meets_threshold.float(), dim=1) # First True index
|
| 302 |
+
# has_valid = meets_threshold.any(dim=1) # Whether any length meets threshold
|
| 303 |
+
|
| 304 |
+
# # Create batch indices for advanced indexing
|
| 305 |
+
# batch_indices = torch.arange(batch_size, device=device)
|
| 306 |
+
|
| 307 |
+
# # Get the smallest valid length for each batch
|
| 308 |
+
# smallest_valid_lengths = test_lengths[batch_indices, valid_indices]
|
| 309 |
+
|
| 310 |
+
# # Update bounds based on results
|
| 311 |
+
# # If has valid length, use it; otherwise use longest length
|
| 312 |
+
# right = torch.where(has_valid, smallest_valid_lengths, test_lengths[:, -1])
|
| 313 |
+
|
| 314 |
+
# # Update left bound: if we found a valid length and it's not the first one,
|
| 315 |
+
# # use the previous length; otherwise keep current left
|
| 316 |
+
# prev_lengths = torch.where(valid_indices > 0, test_lengths[batch_indices, valid_indices - 1], left)
|
| 317 |
+
# left = torch.where(has_valid & (valid_indices > 0), prev_lengths, left)
|
| 318 |
+
|
| 319 |
+
# # Check convergence
|
| 320 |
+
# if (right - left).max() <= 1:
|
| 321 |
+
# break
|
| 322 |
+
|
| 323 |
+
# return right # Return optimal lengths
|
| 324 |
+
|
| 325 |
+
# def _calculate_errors_for_lengths(
|
| 326 |
+
# self,
|
| 327 |
+
# x_tensor: torch.Tensor,
|
| 328 |
+
# indices_flat: torch.Tensor,
|
| 329 |
+
# test_lengths: torch.Tensor,
|
| 330 |
+
# action_encoding: str | None = None,
|
| 331 |
+
# ) -> torch.Tensor:
|
| 332 |
+
# """
|
| 333 |
+
# Calculate reconstruction errors for given token lengths.
|
| 334 |
+
|
| 335 |
+
# Args:
|
| 336 |
+
# x_tensor: Original input tensor (batch_size, ...)
|
| 337 |
+
# indices_flat: Full token indices (batch_size, seq_len)
|
| 338 |
+
# test_lengths: Test lengths tensor (batch_size, num_tests)
|
| 339 |
+
# action_encoding: Action encoding type
|
| 340 |
+
|
| 341 |
+
# Returns:
|
| 342 |
+
# Error tensor (batch_size, num_tests)
|
| 343 |
+
# """
|
| 344 |
+
# # Create sparse tokens for all test lengths (vectorized)
|
| 345 |
+
# batch_size, num_tests = test_lengths.shape
|
| 346 |
+
# seq_len = indices_flat.shape[1]
|
| 347 |
+
# device = indices_flat.device
|
| 348 |
+
|
| 349 |
+
# # Create position tensor for all combinations
|
| 350 |
+
# positions = torch.arange(seq_len, device=device).unsqueeze(0).unsqueeze(0) # (1, 1, seq_len)
|
| 351 |
+
# positions = positions.expand(batch_size, num_tests, -1) # (batch_size, num_tests, seq_len)
|
| 352 |
+
|
| 353 |
+
# # Create length mask: positions < test_lengths
|
| 354 |
+
# length_mask = positions < test_lengths.unsqueeze(2) # (batch_size, num_tests, seq_len)
|
| 355 |
+
|
| 356 |
+
# # Create sparse tokens using advanced indexing
|
| 357 |
+
# sparse_tokens = torch.where(
|
| 358 |
+
# length_mask,
|
| 359 |
+
# indices_flat.unsqueeze(1).expand(-1, num_tests, -1),
|
| 360 |
+
# torch.zeros_like(indices_flat).unsqueeze(1).expand(-1, num_tests, -1),
|
| 361 |
+
# )
|
| 362 |
+
|
| 363 |
+
# # Reshape for parallel processing
|
| 364 |
+
# sparse_flat = sparse_tokens.view(batch_size * num_tests, seq_len)
|
| 365 |
+
|
| 366 |
+
# # Decode all sparse tokens in parallel
|
| 367 |
+
# reconstructed_flat = self._decode_sparse_tokens(sparse_flat, action_encoding)
|
| 368 |
+
|
| 369 |
+
# # Reshape back and calculate errors
|
| 370 |
+
# reconstructed = reconstructed_flat.view(batch_size, num_tests, *x_tensor.shape[1:])
|
| 371 |
+
|
| 372 |
+
# # Calculate errors
|
| 373 |
+
# x_expanded = x_tensor.unsqueeze(1).expand(-1, num_tests, -1, -1)
|
| 374 |
+
# errors = (x_expanded - reconstructed).abs().mean((-1, -2)) # (batch_size, num_tests)
|
| 375 |
+
|
| 376 |
+
# return errors
|
| 377 |
+
|
| 378 |
+
# def _decode_sparse_tokens(self, sparse_tokens: torch.Tensor, action_encoding: str | None = None) -> torch.Tensor:
|
| 379 |
+
# """Decode sparse tokens to reconstructed data."""
|
| 380 |
+
# batch_size, seq_len = sparse_tokens.shape
|
| 381 |
+
|
| 382 |
+
# # Convert to proper indices format for dequantization
|
| 383 |
+
# if self.num_quantizers > 1:
|
| 384 |
+
# seq_len_per_quantizer = seq_len // self.num_quantizers
|
| 385 |
+
# if seq_len % self.num_quantizers != 0:
|
| 386 |
+
# raise ValueError("Sequence length must be divisible by num_quantizers")
|
| 387 |
+
|
| 388 |
+
# indices_for_decode = sparse_tokens.view(batch_size, self.num_quantizers, seq_len_per_quantizer).transpose(
|
| 389 |
+
# 1, 2
|
| 390 |
+
# ) # (batch_size, seq_len_per_quantizer, num_quantizers)
|
| 391 |
+
# else:
|
| 392 |
+
# indices_for_decode = sparse_tokens.unsqueeze(-1) # (batch_size, seq_len, 1)
|
| 393 |
+
|
| 394 |
+
# # Dequantize and decode
|
| 395 |
+
# z_q = self._dequantize(indices_for_decode)
|
| 396 |
+
# reconstructed = self._decode(z_q, action_encoding)
|
| 397 |
+
|
| 398 |
+
# return reconstructed
|
| 399 |
+
|
| 400 |
+
# def _create_sparse_tokens_from_lengths(
|
| 401 |
+
# self, indices_flat: torch.Tensor, optimal_lengths: torch.Tensor
|
| 402 |
+
# ) -> torch.Tensor:
|
| 403 |
+
# """Create sparse tokens based on optimal lengths (vectorized)."""
|
| 404 |
+
# batch_size, seq_len = indices_flat.shape
|
| 405 |
+
# device = indices_flat.device
|
| 406 |
+
|
| 407 |
+
# # Create position mask for all batch items simultaneously
|
| 408 |
+
# positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) # (batch_size, seq_len)
|
| 409 |
+
# length_mask = positions < optimal_lengths.unsqueeze(1) # (batch_size, seq_len)
|
| 410 |
+
|
| 411 |
+
# # Apply mask to create sparse tokens
|
| 412 |
+
# result = torch.where(length_mask, indices_flat, torch.zeros_like(indices_flat))
|
| 413 |
+
|
| 414 |
+
# return result
|
| 415 |
+
|
| 416 |
+
def forward(self, x: torch.Tensor, embodiment_ids: int | None = None, padding_mask: List[bool] | None = None):
    """Tokenize action sequences by delegating to :meth:`encode`.

    Calling the module directly is therefore equivalent to calling ``encode``;
    all arguments are forwarded unchanged.

    Args:
        x: Action sequences to encode.
        embodiment_ids: Embodiment id(s) selecting which robot config applies.
        padding_mask: Per-timestep validity mask (``False`` marks padding).

    Returns:
        Whatever :meth:`encode` returns (the discrete action codes).
    """
    return self.encode(x, embodiment_ids, padding_mask)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
# Register with transformers' Auto API so `AutoModel.from_config(ActionCodecConfig(...))`
# and `AutoModel.from_pretrained(...)` resolve to ActionCodec.
AutoModel.register(ActionCodecConfig, ActionCodec)

# Public API of this module.
__all__ = ["ActionCodec"]
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Smoke test: exercises encode/decode, dynamic embodiment expansion, and
# mixed-embodiment batches. Run as a script; raises AssertionError on failure.
if __name__ == "__main__":
    print("=== ActionCodec Comprehensive Test ===\n")

    # 1. Configuration Setup (RVQ enabled with n_quantizers=4)
    initial_config = {
        "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"},
    }

    # We set n_quantizers=4 to test Residual VQ logic
    config = ActionCodecConfig(
        embodiment_config=initial_config,
        n_tokens=16,  # Total tokens per sequence (latent_len * n_quantizers)
        n_quantizers=4,  # RVQ depth
        vq_type="rvq",
        vq_codebook_size=256,
        encoder_dim=128,
        decoder_dim=128,
    )

    # Expected latent sequence length = n_tokens / n_quantizers = 16 / 4 = 4
    latent_seq_len = int(config.n_tokens // config.n_quantizers)
    print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.")

    codec = ActionCodec(config)
    codec.eval()  # inference mode: disables dropout etc.

    # 2. Basic Encode/Decode Test
    print("\n--- Test 1: Basic Encode/Decode ---")
    batch_size = 2
    seq_len_A = 10  # 10Hz * 1s

    # Create random action data for Robot A (ID 0)
    x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32)
    # Masking: Second item in batch is half padding
    padding_mask = np.ones((batch_size, seq_len_A), dtype=bool)
    padding_mask[1, 5:] = False

    embodiment_ids = [0, 0]

    # Encode
    codes = codec.encode(x, embodiment_ids, padding_mask)
    print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}")

    # Validate code length
    assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}"

    # Decode
    x_recon, recon_mask = codec.decode(codes, embodiment_ids)
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Recon mask shape: {recon_mask.shape}")

    assert x_recon.shape == (batch_size, seq_len_A, 7)  # Should imply zero-padding to max dim 7

    # 3. Expansion Test
    print("\n--- Test 2: Dynamic Expansion ---")
    new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}}

    print("Expanding codec to include Robot B (10 dims, 20Hz)...")
    codec.expand_embodiment(new_robot_config)

    # Both encoder and decoder must now accept the larger action dimension.
    assert codec.encoder.max_action_dim == 10
    assert codec.decoder.max_action_dim == 10
    print("✅ Expansion successful.")

    # 4. Mixed Batch Test (Old + New Robot)
    print("\n--- Test 3: Mixed Batch Inference ---")

    # Batch: [Robot A, Robot B]
    # Robot A: 10Hz, 1s -> 10 steps. Dims 7.
    # Robot B: 20Hz, 1s -> 20 steps. Dims 10.
    # Batch Max Steps: 20. Batch Max Dims: 10.

    batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32)

    # Fill Robot A data (index 0)
    data_A = np.random.randn(10, 7)
    batch_x_mixed[0, :10, :7] = data_A

    # Fill Robot B data (index 1)
    data_B = np.random.randn(20, 10)
    batch_x_mixed[1, :20, :10] = data_B

    # Embodiment IDs: 0 for A, 1 for B
    # Note: expand_embodiment appends. Original was 0, new is 1.
    mixed_ids = [0, 1]

    # Encode Mask
    mixed_mask = np.zeros((2, 20), dtype=bool)
    mixed_mask[0, :10] = True
    mixed_mask[1, :20] = True

    print("Encoding mixed batch...")
    mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask)

    print("Decoding mixed batch...")
    # Explicit durations (optional, but good for verification if we wanted to override defaults)
    durations = [1, 1]
    x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations)

    print(f"Mixed Recon Shape: {x_recon_mixed.shape}")

    # Validation
    # Robot A output check (mask should be True for first 10, False for rest)
    valid_A = dec_mask_mixed[0].sum()
    valid_B = dec_mask_mixed[1].sum()

    print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}")

    assert valid_A == 10
    assert valid_B == 20

    # Check dimensionality preservation
    # Robot A's reconstruction in dims 7-9 should be noise or zero (depending on implementation),
    # but dims 0-6 should contain signal.
    print("✅ Mixed batch processed successfully.")

    print("\n✨ All systems go.")
|
modular_actioncodec.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from typing import List, Literal, Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import einops
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from .configuration_actioncodec import ActionCodecConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved (even, odd) feature pairs of ``x`` by the given angles.

    The rotation is computed in float32 for numerical stability and the result
    is cast back to the input dtype. ``sin``/``cos`` must broadcast against the
    even/odd halves of ``x``'s last dimension.
    """
    in_dtype = x.dtype
    xf, sf, cf = (t.to(torch.float32) for t in (x, sin, cos))

    # Split the last dimension into interleaved pairs (x1 = even slots, x2 = odd slots).
    even, odd = xf[..., 0::2], xf[..., 1::2]

    # 2-D rotation of each (even, odd) pair, written back into interleaved layout.
    out = torch.empty_like(xf)
    out[..., 0::2] = even * cf - odd * sf
    out[..., 1::2] = even * sf + odd * cf

    return out.to(in_dtype)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def attention_op(
|
| 35 |
+
q: torch.Tensor,
|
| 36 |
+
k: torch.Tensor,
|
| 37 |
+
v: torch.Tensor,
|
| 38 |
+
mask: torch.Tensor | None = None,
|
| 39 |
+
is_causal: bool = False,
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
q (torch.Tensor): (*b, h, l, d)
|
| 45 |
+
k (torch.Tensor): (*b, k, s, d)
|
| 46 |
+
v (torch.Tensor): (*b, k, s, d)
|
| 47 |
+
mask (torch.Tensor | None, optional): (*b, l, s), where `True` indicates the element should take part in attention. Defaults to None.
|
| 48 |
+
is_causal (bool, optional): Whether to apply causal mask. Defaults to False.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
torch.Tensor: (*b, h, l, d)
|
| 52 |
+
"""
|
| 53 |
+
heads, kv_heads = q.shape[-3], k.shape[-3]
|
| 54 |
+
if heads != kv_heads:
|
| 55 |
+
assert heads % kv_heads == 0, f"q_heads must be divisible by kv_heads, but got {heads} and {kv_heads}"
|
| 56 |
+
heads_per_kv_head = heads // kv_heads
|
| 57 |
+
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
|
| 58 |
+
|
| 59 |
+
if mask is not None:
|
| 60 |
+
if mask.dim() == 3:
|
| 61 |
+
mask = mask.unsqueeze(1)
|
| 62 |
+
mask = mask.expand(mask.shape[0], heads, -1, -1)
|
| 63 |
+
|
| 64 |
+
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal)
|
| 65 |
+
return out
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class L2Norm(nn.Module):
    """Project the trailing feature dimension onto the unit L2 sphere."""

    def forward(self, x: torch.Tensor):
        # Same computation as F.normalize(x, p=2, dim=-1), including the 1e-12
        # clamp that guards against division by zero for all-zero vectors.
        denom = x.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        return x / denom
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Attention(nn.Module):
    """Multi-head attention with optional grouped-query heads and q/k normalization.

    Supports self-attention (``context is None``) and cross-attention, optional
    rotary position embeddings, and zero-initialized output projection so the
    module can start as a no-op residual branch.

    Args:
        hidden_size (int): Hidden size of the input tensor.
        num_heads (int): Number of attention heads.
        num_kv_heads (int, optional): Number of key/value heads (GQA). Defaults to None (= num_heads).
        qk_norm (Literal["l2", "ln", "none"], optional): Type of normalization to apply to query/key. Defaults to "none".
        bias (bool, optional): Whether to use bias in linear layers. Defaults to False.
        zero_init_output (bool, optional): Zero-initialize the output projection. Defaults to False.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        qk_norm: Literal["l2", "ln", "none"] = "none",
        bias: bool = False,
        zero_init_output: bool = False,
    ):
        super().__init__()
        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        # Per-head feature size; assumes hidden_size is divisible by num_heads.
        self.dim = hidden_size // num_heads
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
        # k/v projections are sized for num_kv_heads, which may be < num_heads (GQA).
        self.k_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
        self.v_proj = nn.Linear(hidden_size, self.dim * num_kv_heads, bias=bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias)

        if qk_norm == "l2":
            self.q_norm = L2Norm()
            self.k_norm = L2Norm()
        elif qk_norm == "ln":
            # Affine-free LayerNorm over the per-head dimension.
            self.q_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
            self.k_norm = nn.LayerNorm(self.dim, elementwise_affine=False)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

        if zero_init_output:
            # With a zero output projection the attention branch contributes
            # nothing at init, which stabilizes residual training.
            nn.init.zeros_(self.out_proj.weight)
            if self.out_proj.bias is not None:
                nn.init.zeros_(self.out_proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
        rotary_pos_emb: Tuple[torch.Tensor, torch.Tensor] | None = None,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """Attend queries from `x` over `context` (self-attention when context is None).

        Args:
            x: (b, l, hidden_size) query sequence.
            context: (b, s, hidden_size) key/value sequence, or None for self-attention.
            mask: boolean attention mask forwarded to `attention_op` (True = attend).
            rotary_pos_emb: optional (sin, cos) pair applied to both q and k.
            is_causal: apply a causal mask inside `attention_op`.

        Returns:
            (b, l, hidden_size) attended output.
        """
        context = x if context is None else context

        q = self.q_proj(x)
        k, v = self.k_proj(context), self.v_proj(context)

        q = einops.rearrange(q, "b l (h d) -> b h l d", h=self.num_heads)
        k = einops.rearrange(k, "b s (h d) -> b h s d", h=self.num_kv_heads)
        v = einops.rearrange(v, "b s (h d) -> b h s d", h=self.num_kv_heads)

        # Normalize after the head split so the norm acts on per-head features.
        q, k = self.q_norm(q), self.k_norm(k)

        if rotary_pos_emb is not None:
            # Rotary embeddings are applied after normalization, to both q and k.
            q, k = map(lambda t: apply_rotary_pos_emb(t, *rotary_pos_emb), (q, k))

        out = attention_op(q, k, v, mask=mask, is_causal=is_causal)
        out = einops.rearrange(out, "b h l d -> b l (h d)")
        out = self.out_proj(out)

        return out
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class PositionalEmbedding(nn.Module):
    """Sequence positional embeddings: deterministic sin/cos or random Fourier features.

    Args:
        dim (int): Embedding dimension (assumed even for the fourier variant,
            which concatenates cos and sin halves).
        encoding_type: "sincos" for classic transformer embeddings indexed by
            integer position, or "fourier" for random Fourier features of real
            timestamps derived from a sampling frequency.
        scale (float): Std-dev multiplier for the random fourier frequencies.
    """

    def __init__(
        self,
        dim: int,
        encoding_type: Literal["sincos", "fourier"] = "sincos",
        scale: float = 2.0,
    ):
        super().__init__()
        self.dim = dim
        self.encoding_type = encoding_type

        if encoding_type == "fourier":
            # Random frequencies are persisted so checkpoints reproduce the embedding.
            self.register_buffer("freqs", torch.randn(dim // 2) * scale, persistent=True)
        elif encoding_type == "sincos":
            pass
        else:
            raise ValueError(f"encoding_type must be 'sincos' or 'fourier', but got {encoding_type}")

    def _create_sincos_emb(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Standard transformer sin/cos table of shape (seq_len, dim)."""
        position = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * -(math.log(10000.0) / self.dim)
        )

        pos_emb = torch.zeros(seq_len, self.dim, device=device, dtype=dtype)
        pos_emb[:, 0::2] = torch.sin(position * div_term).to(dtype)
        pos_emb[:, 1::2] = torch.cos(position * div_term).to(dtype)

        return pos_emb

    def _create_fourier_emb(self, timestamps: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Random Fourier features cos/sin(2*pi*f*t), shape (b, t, dim)."""
        pos_emb = torch.einsum("b t, d -> b t d", timestamps, 2 * np.pi * self.freqs).to(device, torch.float32)
        pos_emb = torch.cat([pos_emb.cos(), pos_emb.sin()], dim=-1).to(dtype)
        return pos_emb

    def forward(
        self, x: torch.Tensor, freq: Optional[Union[float, torch.Tensor]] = None, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """Return positional embeddings of shape (b, t, dim), scaled by 0.1.

        Only x's batch size, sequence length and device are used; its values are
        ignored. `freq` (scalar or per-batch tensor of sampling frequencies) is
        required for the fourier variant.
        """
        b, t = x.shape[0], x.shape[1]
        device = x.device

        if self.encoding_type == "sincos":
            pos_emb = self._create_sincos_emb(t, device, dtype)
            pos_emb = pos_emb.unsqueeze(0).expand(b, -1, -1)
            # 0.1 keeps the embedding small relative to token features.
            return pos_emb * 0.1

        elif self.encoding_type == "fourier":
            if freq is None:
                raise ValueError(
                    "freq must be provided when encoding_type is 'fourier'. Please provide the sequence frequency."
                )
            # Accept any Python scalar (int or float): an int freq previously fell
            # through this branch and crashed in the einsum below.
            if isinstance(freq, (int, float)):
                freq = torch.tensor(float(freq), dtype=dtype, device=device)[None].expand(b)
            # timestamps[b, t] = t / freq[b]: real time of each step in seconds.
            timestamps = torch.einsum("t, b -> b t", torch.arange(t, dtype=dtype, device=device), 1 / freq)
            pos_emb = self._create_fourier_emb(timestamps, device, dtype)
            return pos_emb * 0.1
        else:
            raise ValueError(f"Unknown encoding_type: {self.encoding_type}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class SinusoidalPositionalEmbedding(PositionalEmbedding):
    """Convenience wrapper: a PositionalEmbedding fixed to the 'sincos' variant."""

    def __init__(self, dim: int):
        super().__init__(dim=dim, encoding_type="sincos")

    def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        # `pos` is accepted for interface compatibility but ignored — sincos
        # embeddings depend only on the sequence length of `x`.
        return super().forward(x, freq=None)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class FeedForward(nn.Module):
    """Two-layer GELU MLP: hidden_size -> intermediate_size -> hidden_size."""

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        # Attribute names are part of the contract: callers reach into
        # `mlp.down_proj` for zero-initialization.
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
        self.act_fn = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act_fn(self.up_proj(x))
        return self.down_proj(hidden)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class LayerScale(nn.Module):
    """Learnable per-channel gain, initialized to a small constant (LayerScale)."""

    def __init__(self, dim, init_val=1e-2):
        super().__init__()
        # One gain per channel, all starting at init_val.
        self.scale = nn.Parameter(init_val * torch.ones(dim))

    def forward(self, x):
        return self.scale * x
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class PerceiverTransformerBlock(nn.Module):
    """Perceiver-style block: cross-attention to a context, optional self-attention, then MLP.

    Each sub-layer is pre-norm with a residual connection, dropout, and optional
    LayerScale gating.

    Args:
        dim: Hidden size of both the latent queries and the context.
        num_heads: Attention heads for both attention sub-layers.
        mlp_ratio: MLP intermediate size = mlp_ratio * dim.
        dropout: Dropout applied after each sub-layer.
        qk_norm: Query/key normalization mode passed to Attention.
        layer_scale: Gate each residual branch with LayerScale.
        zero_init_output: Zero-init attention output and MLP down projections.
        add_self_attn: Insert a self-attention sub-layer between cross-attn and MLP.
        add_causal_mask: Make that self-attention causal.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
        qk_norm: str = "ln",
        layer_scale: bool = True,
        zero_init_output: bool = False,
        add_self_attn: bool = False,
        add_causal_mask: bool = False,
    ):
        super().__init__()
        self.add_self_attn = add_self_attn
        self.add_causal_mask = add_causal_mask

        # NOTE(review): eps=1e-2 is unusually large for LayerNorm (default 1e-5);
        # presumably intentional for stability — confirm before changing.
        self.norm1 = nn.LayerNorm(dim, eps=1e-2)
        self.cross_attn = Attention(
            hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
        )

        if add_self_attn:
            self.norm_self_attn = nn.LayerNorm(dim, eps=1e-2)
            self.self_attn = Attention(
                hidden_size=dim, num_heads=num_heads, qk_norm=qk_norm, bias=False, zero_init_output=zero_init_output
            )
        else:
            self.self_attn = None

        self.norm2 = nn.LayerNorm(dim, eps=1e-2)
        self.mlp = FeedForward(hidden_size=dim, intermediate_size=int(mlp_ratio * dim), bias=True)
        self.dropout = nn.Dropout(dropout)

        # NOTE(review): attn_scale is shared (same parameters) between the
        # cross-attention and self-attention residuals — confirm this is intended.
        self.attn_scale = LayerScale(dim) if layer_scale else nn.Identity()
        self.mlp_scale = LayerScale(dim) if layer_scale else nn.Identity()

        if zero_init_output:
            # Zero the MLP output too, so the whole block is identity at init.
            nn.init.zeros_(self.mlp.down_proj.weight)
            if self.mlp.down_proj.bias is not None:
                nn.init.zeros_(self.mlp.down_proj.bias)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        context_mask: Optional[torch.Tensor] = None,
        rotary_pos_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Update latents `x` (b, l, dim) by attending over `context` (b, s, dim).

        `context_mask` (True = attend) masks context positions; `rotary_pos_emb`
        is forwarded to both attention sub-layers. Returns (b, l, dim).
        """
        # Cross-attention sub-layer (pre-norm, residual).
        residual = x
        x = self.norm1(x)
        x = self.cross_attn(x=x, context=context, mask=context_mask, rotary_pos_emb=rotary_pos_emb, is_causal=False)
        x = self.dropout(x)
        x = self.attn_scale(x)
        x = x + residual

        # Optional (possibly causal) self-attention over the latents.
        if self.add_self_attn:
            residual = x
            x = self.norm_self_attn(x)
            x = self.self_attn(
                x=x,
                context=None,
                mask=None,
                rotary_pos_emb=rotary_pos_emb,
                is_causal=self.add_causal_mask,
            )
            x = self.dropout(x)
            x = self.attn_scale(x)
            x = x + residual

        # MLP sub-layer (pre-norm, residual).
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.dropout(x)
        x = self.mlp_scale(x)
        x = x + residual

        return x
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class EmbodimentEmbedding(nn.Module):
    """Learned per-embodiment token sequences.

    Each embodiment id maps to a flat vector of size ``out_len * out_dim`` that
    is reshaped to ``(out_len, out_dim)`` tokens on lookup. Ids are assigned by
    registration order of ``embodiment_config`` keys.
    """

    def __init__(self, embodiment_config: dict, out_len: int, out_dim: int) -> None:
        super().__init__()
        self.out_len, self.out_dim = out_len, out_dim

        self.embodiment_config = embodiment_config
        self.num_embodiments = len(self.embodiment_config)

        self.embedding = nn.Embedding(self.num_embodiments, out_dim * out_len)

    @torch.no_grad()
    def expand_embodiment(self, embodiment_config: dict):
        """Append new embodiments, preserving existing embedding rows.

        New rows get fresh init. The replacement embedding is created on the same
        device/dtype as the old weights, so expansion also works for modules that
        were already moved to an accelerator or cast to another precision
        (a bare ``nn.Embedding`` would land on CPU/float32).
        """
        for k in embodiment_config.keys():
            assert k not in self.embodiment_config.keys(), f"embodiment '{k}' is already registered"
        self.embodiment_config.update(embodiment_config)
        self.num_embodiments = len(self.embodiment_config)

        extra_embodiments = len(embodiment_config)

        old_weights = torch.clone(self.embedding.weight)
        self.embedding = nn.Embedding(self.num_embodiments, self.out_dim * self.out_len).to(
            device=old_weights.device, dtype=old_weights.dtype
        )
        # Copy the previously learned rows back; only the appended tail is new.
        self.embedding.weight.data[:-extra_embodiments] = old_weights
        return self

    def keys(self) -> list[str]:
        """Embodiment names in registration order (the order defines the ids)."""
        return list(self.embodiment_config.keys())

    def ids_to_keys(self, ids: torch.Tensor) -> List[str]:
        keys = self.keys()  # hoisted: one dict traversal for the whole batch
        return [keys[i] for i in ids]

    def keys_to_ids(self, keys: List[str]) -> torch.Tensor:
        all_keys = self.keys()  # hoisted out of the comprehension
        return torch.tensor([all_keys.index(k) for k in keys])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up token sequences: ids (b,) -> tokens (b, out_len, out_dim)."""
        return einops.rearrange(self.embedding(x), "b (l d) -> b l d", d=self.out_dim)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class PerceiverEncoder(nn.Module):
    """Perceiver-style action encoder.

    A fixed set of learned per-embodiment query ("cls") tokens cross-attends over
    the projected action sequence; the updated queries are projected down to the
    latent dimension ``z_dim``.
    """

    def __init__(self, config: ActionCodecConfig):
        super().__init__()
        self.config = config
        # Own copy so in-place embodiment expansion cannot mutate the shared config.
        self.embodiment_config = deepcopy(config.embodiment_config)

        # Number of latent query tokens produced per quantizer level.
        out_len = int(config.n_tokens // config.n_quantizers)
        dim = config.encoder_dim

        # Per-embodiment metadata, indexable by embodiment id (registration order).
        _action_dim, _freq, _duration = list(), list(), list()
        for k, v in self.embodiment_config.items():
            _action_dim.append(v["action_dim"])
            _freq.append(v["freq"])
            _duration.append(v["duration"])
        # Non-persistent: derivable from embodiment_config, so not checkpointed.
        self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
        self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
        self.register_buffer("_duration", torch.tensor(_duration), persistent=False)

        # Inputs are expected zero-padded on the feature axis to the largest action_dim.
        self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        self.input_proj = nn.Linear(self.max_action_dim, dim)

        # Learned queries: one sequence of `out_len` tokens per embodiment.
        self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, out_len, dim)

        # Queries use index-based sincos positions; keys/values use the config's
        # choice (e.g. timestamp-based fourier positions driven by `freq`).
        self.pos_emb_q = PositionalEmbedding(dim, encoding_type="sincos")
        self.pos_emb_kv = PositionalEmbedding(dim, encoding_type=config.encoder_pos_encoding_type)

        self.layers = nn.ModuleList(
            [
                PerceiverTransformerBlock(
                    dim=dim,
                    num_heads=config.encoder_n_heads,
                    add_self_attn=config.encoder_add_self_attn,
                    add_causal_mask=config.encoder_add_causal_mask,
                )
                for _ in range(config.encoder_n_layers)
            ]
        )

        self.output_proj = nn.Linear(dim, config.z_dim)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init (std=0.02) for projections and query tokens; zero biases."""
        nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
        if self.input_proj.bias is not None:
            nn.init.zeros_(self.input_proj.bias)
        nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
        if self.output_proj.bias is not None:
            nn.init.zeros_(self.output_proj.bias)

        nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)

    @torch.no_grad()
    def expand_embodiment(self, embodiment_config: dict):
        """Register additional embodiments, growing input_proj if action_dim grows.

        Existing weights are preserved; only the new input columns (if any) are
        freshly initialized.
        """
        self.cls_tokens.expand_embodiment(embodiment_config)
        self.embodiment_config = self.cls_tokens.embodiment_config
        _action_dim, _freq, _duration = list(), list(), list()
        for k, v in self.embodiment_config.items():
            _action_dim.append(v["action_dim"])
            _freq.append(v["freq"])
            _duration.append(v["duration"])
        # NOTE(review): plain assignment recreates these buffers on CPU — looks
        # like expansion is expected to happen before moving to an accelerator;
        # confirm against callers.
        self._action_dim = torch.tensor(_action_dim)
        self._freq = torch.tensor(_freq)
        self._duration = torch.tensor(_duration)

        max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        if max_action_dim > self.max_action_dim:
            # Grow the input projection, copying old columns into the new layer;
            # the extra columns keep the fresh nn.Linear initialization.
            old_weights = torch.clone(self.input_proj.weight)
            old_bias = torch.clone(self.input_proj.bias)
            self.input_proj = nn.Linear(max_action_dim, self.config.encoder_dim)
            self.input_proj.weight.data[:, : self.max_action_dim] = old_weights
            self.input_proj.bias.data = old_bias
            self.max_action_dim = max_action_dim

        return self

    def forward(
        self,
        x: torch.Tensor,
        embodiment_ids: torch.Tensor | int,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode action sequences into latent representations.

        Args:
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                Assumes that the action dimension is zero-padded to the max action dimension.
                `seq_len` is supposed to be `int(duration * freq)` for each embodiment and padded to the max sequence length.
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to encode.
            padding_mask (Optional[torch.Tensor], optional): Padding mask, where `False` values indicate padding. Shape: (b, seq_len). Defaults to None.
                It is used to mask the padding tokens on `seq_len` dimension.

        Returns:
            torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim).
        """
        b, seq_len, _ = x.shape

        x = self.input_proj(x)

        if isinstance(embodiment_ids, int):
            embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)

        # (b, out_len, dim) learned queries for each sequence's embodiment.
        cls_tokens = self.cls_tokens(embodiment_ids)

        # Per-sequence sampling frequency, consumed by fourier positional embeddings.
        freqs = self._freq[embodiment_ids].to(x.device, x.dtype)

        pos_emb_q = self.pos_emb_q(cls_tokens)
        pos_emb_kv = self.pos_emb_kv(x, freqs)

        cls_tokens = cls_tokens + pos_emb_q
        x = x + pos_emb_kv

        if padding_mask is not None:
            # Broadcast the (b, seq_len) key mask over every query position.
            padding_mask = padding_mask.unsqueeze(1).expand(-1, cls_tokens.shape[1], -1)

        for layer in self.layers:
            cls_tokens = layer(x=cls_tokens, context=x, context_mask=padding_mask)

        return self.output_proj(cls_tokens)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
class PerceiverDecoder(nn.Module):
    """Perceiver-style decoder.

    Embodiment-specific query ("cls") tokens cross-attend to the latent tokens
    produced by the encoder/quantizer and are projected to per-timestep action
    vectors, zero-padded to the largest action dimension across embodiments.
    """

    def __init__(self, config: ActionCodecConfig):
        super().__init__()
        self.config = config
        # Private copy so later expand_embodiment() calls don't mutate the shared config.
        self.embodiment_config = deepcopy(config.embodiment_config)

        dim = config.decoder_dim

        # Per-embodiment metadata (action dim, control frequency, default duration),
        # stored as non-persistent buffers: they follow the module across devices
        # but are not written to checkpoints.
        _action_dim, _freq, _duration = list(), list(), list()
        for k, v in self.embodiment_config.items():
            _action_dim.append(v["action_dim"])
            _freq.append(v["freq"])
            _duration.append(v["duration"])
        self.register_buffer("_action_dim", torch.tensor(_action_dim), persistent=False)
        self.register_buffer("_freq", torch.tensor(_freq), persistent=False)
        self.register_buffer("_duration", torch.tensor(_duration), persistent=False)

        # Output head is sized to the largest action dimension; smaller
        # embodiments use a zero-padded prefix of it.
        self.max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())
        self.input_proj = nn.Linear(config.z_dim, dim)

        # Learned query tokens, one set per embodiment.
        self.cls_tokens = EmbodimentEmbedding(self.embodiment_config, config.decoder_cls_size, dim)

        # Separate positional embeddings for the queries (configurable type)
        # and the latent keys/values (fixed sin-cos).
        self.pos_emb_q = PositionalEmbedding(dim, encoding_type=config.decoder_pos_encoding_type)
        self.pos_emb_kv = PositionalEmbedding(dim, encoding_type="sincos")

        self.layers = nn.ModuleList(
            [
                PerceiverTransformerBlock(
                    dim=dim,
                    num_heads=config.decoder_n_heads,
                    add_self_attn=config.decoder_add_self_attn,
                    add_causal_mask=config.decoder_add_causal_mask,
                )
                for _ in range(config.decoder_n_layers)
            ]
        )

        self.output_proj = nn.Linear(dim, self.max_action_dim)
        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for projections and query embeddings; zero biases.
        nn.init.trunc_normal_(self.input_proj.weight, std=0.02)
        if self.input_proj.bias is not None:
            nn.init.zeros_(self.input_proj.bias)
        nn.init.trunc_normal_(self.output_proj.weight, std=0.02)
        if self.output_proj.bias is not None:
            nn.init.zeros_(self.output_proj.bias)
        nn.init.trunc_normal_(self.cls_tokens.embedding.weight, std=0.02)

    @torch.no_grad()
    def expand_embodiment(self, embodiment_config: dict):
        """Register additional embodiments in place, preserving existing weights.

        Rebuilds the metadata buffers and, when a new embodiment has a larger
        action dimension, grows `output_proj` by copying the old rows into a
        freshly initialized larger layer, so existing embodiments keep
        bit-identical outputs on their valid dimensions.

        NOTE(review): the buffers are re-bound with plain attribute assignment
        (no device/dtype handling) — if the module lives on GPU these new
        tensors land on CPU until the next `.to(device)`; confirm callers
        move the module afterwards.
        """
        self.cls_tokens.expand_embodiment(embodiment_config)
        self.embodiment_config = self.cls_tokens.embodiment_config

        _action_dim, _freq, _duration = list(), list(), list()
        for k, v in self.embodiment_config.items():
            _action_dim.append(v["action_dim"])
            _freq.append(v["freq"])
            _duration.append(v["duration"])
        self._action_dim = torch.tensor(_action_dim)
        self._freq = torch.tensor(_freq)
        self._duration = torch.tensor(_duration)

        max_action_dim = max(v["action_dim"] for v in self.embodiment_config.values())

        if max_action_dim > self.max_action_dim:
            # Grow the output head; copy old rows so old embodiments decode identically.
            old_weights = torch.clone(self.output_proj.weight)
            old_bias = torch.clone(self.output_proj.bias)

            self.output_proj = nn.Linear(self.config.decoder_dim, max_action_dim)

            self.output_proj.weight.data[: self.max_action_dim, :] = old_weights
            self.output_proj.bias.data[: self.max_action_dim] = old_bias

            self.max_action_dim = max_action_dim

        return self

    def forward(
        self, x: torch.Tensor, embodiment_ids: torch.Tensor | int, durations: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode latent representations into action sequences.

        Args:
            x (torch.Tensor): Latent representations to decode. Shape: (b, n_tokens_per_quantizer, z_dim).
            embodiment_ids (torch.Tensor | int): Embodiment IDs. Shape: (b,).
                If int, the same embodiment ID is repeated for all sequences in the batch.
                It specifies the embodiment to decode.
            durations (torch.Tensor | None, optional): Duration of each action sequence. Shape: (b,).
                If `None`, the duration is inferred from the default values in `embodiment_config`.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Decoded action sequences. Shape: (b, max_horizon, max_action_dim), where
                  `max_horizon` is the largest `int(duration * freq)` in the batch and the
                  action dimension is zero-padded to the max action dimension.
                - Boolean padding mask. Shape: (b, max_horizon); True marks valid timesteps
                  (the first `int(duration * freq)` steps of each sequence).
        """
        b, seq_len, _ = x.shape
        x = self.input_proj(x)

        if isinstance(embodiment_ids, int):
            embodiment_ids = torch.tensor([embodiment_ids], dtype=torch.long, device=x.device).repeat(b)

        cls_tokens = self.cls_tokens(embodiment_ids)

        # Per-sample horizon = duration * control frequency; mask marks valid steps.
        freqs = self._freq[embodiment_ids]
        durations = self._duration[embodiment_ids] if durations is None else durations
        action_horizons = (durations * freqs).long()
        max_horizon = action_horizons.max().item()
        padding_mask = torch.arange(max_horizon, device=x.device).expand(b, -1) < action_horizons.unsqueeze(1)

        # With a single query token per embodiment, tile it across the horizon;
        # positional embeddings then distinguish the timesteps.
        if self.config.decoder_cls_size == 1:
            cls_tokens = cls_tokens.repeat(1, max_horizon, 1)

        pos_emb_q = self.pos_emb_q(cls_tokens, freqs)
        pos_emb_kv = self.pos_emb_kv(x)

        cls_tokens = cls_tokens + pos_emb_q
        x = x + pos_emb_kv

        # Queries cross-attend to the latent tokens.
        for layer in self.layers:
            cls_tokens = layer(x=cls_tokens, context=x)

        output = self.output_proj(cls_tokens)

        return output, padding_mask
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
if __name__ == "__main__":
    # Smoke test: verifies that expand_embodiment() preserves old-embodiment
    # behavior exactly while enabling new, larger embodiments and mixed batches.
    # ------------------------------------------
    # 1. Initialization
    # ------------------------------------------
    print("=== Test 1: Initialization ===")

    # Define initial config with two smaller robots
    initial_embodiment_config = {
        "robot_small_7d": {"action_dim": 7, "freq": 20, "duration": 1, "description": "Original Robot"},
        "robot_tiny_3d": {"action_dim": 3, "freq": 10, "duration": 2, "description": "Tiny Robot"},
    }

    config = ActionCodecConfig(embodiment_config=initial_embodiment_config)

    # Set seed for reproducibility
    torch.manual_seed(42)

    encoder = PerceiverEncoder(config)
    decoder = PerceiverDecoder(config)

    # eval() so dropout/etc. cannot break the invariance comparisons below.
    encoder.eval()
    decoder.eval()
    print("✅ Models initialized successfully.")

    # ------------------------------------------
    # 2. Baseline Inference (Before Expansion)
    # ------------------------------------------
    print("\n=== Test 2: Baseline Inference (Before Expansion) ===")

    # Simulate Robot 1 (7-dim) data
    # Max action dim currently is 7.
    batch_size = 1
    seq_len = 20  # 20Hz * 1s

    # Input: (1, 20, 7)
    input_action_v0 = torch.randn(batch_size, seq_len, 7)
    emb_id_v0 = torch.tensor([0], dtype=torch.long)  # ID 0 -> robot_small_7d

    with torch.no_grad():
        z_ref = encoder(input_action_v0, emb_id_v0)
        rec_action_ref, _ = decoder(z_ref, emb_id_v0)

    print(f"Reference Latent Shape: {z_ref.shape}")
    print(f"Reference Recon Shape: {rec_action_ref.shape}")

    # ------------------------------------------
    # 3. Model Expansion (Add New Embodiment)
    # ------------------------------------------
    print("\n=== Test 3: Model Expansion ===")

    # Add a larger robot: 10-dim, high frequency
    new_embodiment_config = {
        "robot_large_10d": {"action_dim": 10, "freq": 30, "duration": 1, "description": "New Large Robot"}
    }

    print(f"Expanding from Max Dim {encoder.max_action_dim} to 10...")
    encoder.expand_embodiment(new_embodiment_config)
    decoder.expand_embodiment(new_embodiment_config)

    # Verify buffer updates
    assert encoder._action_dim[-1] == 10
    assert encoder.max_action_dim == 10
    assert decoder.max_action_dim == 10
    print(f"✅ Expansion successful. New Encoder Input Dim: {encoder.input_proj.weight.shape[1]}")
    print(f"✅ New Decoder Output Dim: {decoder.output_proj.weight.shape[0]}")

    # ------------------------------------------
    # 4. Encoder Invariance Check
    # ------------------------------------------
    print("\n=== Test 4: Encoder Invariance Check ===")

    # Pad old data (7 dims) to new max dim (10 dims) with ZEROS.
    # Zero padding must be a no-op for the expanded input projection.
    input_action_padded = torch.zeros(batch_size, seq_len, 10)
    input_action_padded[:, :, :7] = input_action_v0

    with torch.no_grad():
        z_new = encoder(input_action_padded, emb_id_v0)

    # Compare latents
    diff_z = (z_ref - z_new).abs().max().item()
    print(f"Latent Difference (Max Abs): {diff_z:.8f}")

    if diff_z < 1e-6:
        print("✅ PASS: Encoder produces identical latents for old data.")
    else:
        print("❌ FAIL: Encoder outputs changed after expansion!")

    # ------------------------------------------
    # 5. Decoder Invariance Check
    # ------------------------------------------
    print("\n=== Test 5: Decoder Invariance Check ===")

    with torch.no_grad():
        # Feed old latent to expanded decoder
        rec_action_new_full, _ = decoder(z_ref, emb_id_v0)

    # Output shape should be (1, 20, 10)
    print(f"Expanded Decoder Output Shape: {rec_action_new_full.shape}")

    # Slice first 7 dims, should match reference
    rec_action_new_sliced = rec_action_new_full[:, :, :7]

    diff_rec = (rec_action_ref - rec_action_new_sliced).abs().max().item()
    print(f"Reconstruction Difference (Max Abs on valid dims): {diff_rec:.8f}")

    if diff_rec < 1e-6:
        print("✅ PASS: Decoder produces identical action values for valid dimensions.")
    else:
        print("❌ FAIL: Decoder outputs changed!")

    # Check phantom dimensions (7-9)
    # For old embodiment, these are driven by random weights and should be random
    new_dims_mean = rec_action_new_full[:, :, 7:].abs().mean().item()
    print(f"Values in new phantom dimensions (should be random garbage): {new_dims_mean:.4f}")

    # ------------------------------------------
    # 6. New Embodiment Inference
    # ------------------------------------------
    print("\n=== Test 6: New Embodiment Inference ===")

    # ID 2 -> robot_large_10d
    emb_id_new = torch.tensor([2], dtype=torch.long)
    seq_len_new = 30  # 30Hz * 1s

    input_action_new = torch.randn(1, seq_len_new, 10)

    with torch.no_grad():
        z_large = encoder(input_action_new, emb_id_new)
        rec_large, mask_large = decoder(z_large, emb_id_new)

    print(f"New Embodiment Output Shape: {rec_large.shape}")

    if rec_large.shape == (1, 30, 10):
        print("✅ PASS: New embodiment handled correctly with full dimensions.")
    else:
        print(f"❌ FAIL: Expected (1, 30, 10), got {rec_large.shape}")

    # ------------------------------------------
    # 7. Mixed Batch Processing (Masking)
    # ------------------------------------------
    print("\n=== Test 7: Mixed Batch Processing ===")

    # Batch size 2: [Robot 0 (20Hz, 7dim), Robot 2 (30Hz, 10dim)]
    mixed_emb_ids = torch.tensor([0, 2], dtype=torch.long)

    # Max seq len is 30. Max action dim is 10.
    batch_input = torch.zeros(2, 30, 10)

    # Fill data
    # Batch 0: Length 20, Dim 7 valid
    batch_input[0, :20, :7] = torch.randn(20, 7)
    # Batch 1: Length 30, Dim 10 valid
    batch_input[1, :30, :10] = torch.randn(30, 10)

    # Encoder Mask: True = Valid
    enc_padding_mask = torch.zeros(2, 30, dtype=torch.bool)
    enc_padding_mask[0, :20] = True
    enc_padding_mask[1, :30] = True

    print("Running mixed batch...")
    with torch.no_grad():
        z_mixed = encoder(batch_input, mixed_emb_ids, padding_mask=enc_padding_mask)
        rec_mixed, dec_padding_mask = decoder(z_mixed, mixed_emb_ids)

    print(f"Mixed Reconstruction Shape: {rec_mixed.shape}")  # Should be (2, 30, 10)

    # Verify Decoder Generated Mask
    valid_len_0 = dec_padding_mask[0].sum().item()
    valid_len_1 = dec_padding_mask[1].sum().item()

    print(f"Decoder Mask Valid Lengths: Batch 0={valid_len_0}, Batch 1={valid_len_1}")

    if valid_len_0 == 20 and valid_len_1 == 30:
        print("✅ PASS: Decoder correctly generated masks based on frequency and duration.")
    else:
        print("❌ FAIL: Decoder masks are incorrect.")

    print("\n✨ All Tests Completed ✨")
|
rvq.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributed as dist
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from vector_quantize_pytorch import VectorQuantize as torchVQ
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sample_vectors(samples, num):
    """Draw `num` row vectors from `samples` (shape (N, D)).

    Sampling is without replacement when N >= num (random permutation prefix),
    otherwise with replacement. The result is always cast to fp32.
    """
    n_available = samples.shape[0]
    dev = samples.device
    if n_available < num:
        # Not enough rows: draw with replacement.
        picks = torch.randint(0, n_available, (num,), device=dev)
    else:
        # Enough rows: random permutation prefix (no replacement).
        picks = torch.randperm(n_available, device=dev)[:num]
    return samples[picks].float()  # (num, D), fp32
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def ema_inplace(moving_avg, new, decay):
    """In-place EMA step: moving_avg <- decay * moving_avg + (1 - decay) * new.

    `moving_avg` and `new` share a shape — (codebook_size,) or
    (codebook_size, D'). `new` is cast to fp32 before the update.
    """
    decayed = moving_avg.data.mul_(decay)
    decayed.add_(new.float(), alpha=1.0 - decay)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _neg_sq_dists(samples, means):
    """Negative squared Euclidean distances between rows of `samples` and `means`.

    samples: (N, D), means: (K, D) -> (N, K), computed in fp32 via the
    expansion ||s - m||^2 = ||s||^2 - 2 s·m + ||m||^2.
    """
    s = samples.float()
    m = means.float()
    return -(
        s.pow(2).sum(1, keepdim=True)  # (N, 1)
        - 2 * s @ m.t()  # (N, K)
        + m.pow(2).sum(1, keepdim=True).t()  # (1, K)
    )


def kmeans(samples, num_clusters, num_iters=10):
    """Lloyd's k-means on `samples` (shape (N, D)), entirely in fp32.

    Centroids are initialized from random rows of `samples` (see
    `sample_vectors`); clusters that receive no samples keep their previous
    centroid. The duplicated distance computation of the original has been
    hoisted into `_neg_sq_dists`, and the dead `torch.float32` tuple
    assignment removed.

    Args:
        samples (torch.Tensor): (N, D) data points.
        num_clusters (int): number of centroids K.
        num_iters (int): Lloyd iterations.

    Returns:
        tuple: (means, bins) —
            means (torch.Tensor): (K, D) final centroids, fp32.
            bins (torch.Tensor): (K,) fp32 cluster sizes under the final assignment.
    """
    dim = samples.shape[-1]
    means = sample_vectors(samples, num_clusters).float()  # (K, D)

    for _ in range(num_iters):
        # Assign each sample to its nearest centroid.
        buckets = _neg_sq_dists(samples, means).max(dim=-1).indices  # (N,)
        bins = torch.bincount(buckets, minlength=num_clusters)  # (K,)
        zero_mask = bins == 0  # (K,)
        bins_min_clamped = bins.masked_fill(zero_mask, 1)  # avoid division by zero

        # Recompute centroids as the mean of their assigned samples.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)  # (K, D)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float())
        new_means = new_means / bins_min_clamped[..., None]  # (K, D)
        # Empty clusters retain their old centroid.
        means = torch.where(zero_mask[..., None], means, new_means)  # (K, D)

    # Final assignment, used only to report cluster sizes.
    buckets = _neg_sq_dists(samples, means).max(dim=-1).indices  # (N,)
    bins = torch.bincount(buckets, minlength=num_clusters).float()  # (K,)

    return means, bins
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class VectorQuantize(nn.Module):
    """Single EMA-updated vector quantizer.

    Inputs of size `input_dim` are (optionally) projected down to
    `codebook_dim`, snapped to the nearest codebook row by squared Euclidean
    distance, and projected back. The codebook lives in buffers and is
    maintained by exponential moving averages plus dead-code resampling —
    not by gradient descent; gradients reach the encoder through a
    straight-through estimator. All quantizer arithmetic is forced to fp32.

    Fixes vs. the previous revision:
      * `replace_dead_codes` / `init_codebook` only sampled / ran k-means when
        `dist.is_initialized() and rank == 0`, so single-process runs filled
        dead codes (and the initial codebook) with zeros. Rank 0 — or the sole
        process — now always does the real work.
      * The local variable `dist` in `forward` shadowed the module alias
        `torch.distributed as dist`; renamed to `sq_dists`.
      * `decode_code`'s shape comment corrected to (B, T, D').
    """

    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,  # EMA decay
        epsilon=1e-5,  # Laplace smoothing epsilon
        threshold_ema_dead=2,  # dead-code replacement threshold (EMA cluster size)
        kmeans_init=True,  # initialize the codebook from the first batch via k-means
        kmeans_iters=10,  # k-means iterations for that init
        rotation_trick=False,  # stored but currently unused in forward
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.rotation_trick = rotation_trick

        # Factorized codes (ViT-VQGAN style): project only when dims differ.
        if self.input_dim != self.codebook_dim:
            self.in_project = nn.Linear(input_dim, codebook_dim)
            self.out_project = nn.Linear(codebook_dim, input_dim)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        # Codebook starts at zeros when k-means init is pending (it is filled
        # from the first batch), otherwise random normal.
        init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
        self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float())  # (K, D'), fp32
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))  # (1,)
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())  # (K,) EMA usage counts
        self.register_buffer("embed_avg", self.codebook.clone().float())  # (K, D') EMA vector sums

    def ema_update(self, encodings, embed_onehot):
        """EMA-update the codebook from one batch.

        Args:
            encodings: (B*T, D') encoder outputs (cast to fp32 here).
            embed_onehot: (B*T, K) one-hot assignment matrix.
        """
        encodings = encodings.float()
        embed_onehot = embed_onehot.float()
        cluster_size_new = embed_onehot.sum(0)  # (K,)
        embed_sum = encodings.t() @ embed_onehot  # (D', K)

        # Aggregate assignment statistics across ranks before the EMA step.
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)

        ema_inplace(self.cluster_size, cluster_size_new, self.decay)  # (K,)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)  # (K, D')

        # Laplace smoothing keeps near-empty clusters from exploding the division.
        cluster_size = (self.cluster_size + self.epsilon) / (
            self.cluster_size.sum() + self.codebook_size * self.epsilon
        )  # (K,)
        cluster_size = cluster_size * self.cluster_size.sum()  # (K,)
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))  # (K, D')

    def replace_dead_codes(self, encodings):
        """Resample codebook rows whose EMA usage fell below the threshold.

        Args:
            encodings: (B*T, D') current-batch encoder outputs used as the
                resampling pool.
        """
        if self.threshold_ema_dead == 0:
            return

        dead_mask = self.cluster_size < self.threshold_ema_dead  # (K,)
        if dead_mask.any():
            # BUGFIX: previously gated on `dist.is_initialized() and rank == 0`,
            # which made single-process runs take the zero-placeholder branch
            # and silently replace dead codes with zeros.
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)  # (K, D'), fp32
                print(f"Replace {dead_mask.sum().item()} dead codes")
            else:
                samples = torch.zeros_like(self.codebook).float()  # placeholder, overwritten by broadcast

            # Keep all ranks' codebooks identical.
            if dist.is_initialized():
                dist.broadcast(samples, src=0)

            self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype)

    def init_codebook(self, encodings):
        """One-time k-means initialization of the codebook and EMA state.

        Args:
            encodings: (B*T, D') first-batch encoder outputs.
        """
        if self.inited.item():
            return

        # BUGFIX: same gating issue as replace_dead_codes — single-process
        # runs previously skipped k-means and kept an all-zero codebook while
        # still marking `inited`. Rank 0 / the sole process now runs k-means.
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(
                encodings.float(), self.codebook_size, self.kmeans_iters
            )  # (K, D'), (K,), fp32
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float()
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)

        # Broadcast so every rank starts from the same codebook.
        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)

        self.codebook.copy_(embed)  # (K, D')
        self.embed_avg.copy_(embed.clone())  # (K, D')
        self.cluster_size.copy_(cluster_sizes.float())  # (K,)
        self.inited.fill_(True)

    def forward(self, z):
        """Quantize `z` of shape (B, T, input_dim).

        Returns:
            z_q (torch.Tensor): (B, T, input_dim) quantized output with
                straight-through gradients.
            commit_loss (torch.Tensor): scalar commitment loss, scaled by
                `self.commitment`.
            codebook_loss (torch.Tensor): scalar 0.0 — the codebook is
                EMA-updated, not trained by a loss.
            indices (torch.Tensor): (B, T) chosen codebook indices.
            z_e (torch.Tensor): (B, T, codebook_dim) pre-quantization projections.
        """
        # Force module params/buffers and inputs to fp32 (Module.to is in-place).
        self = self.to(torch.float32)
        z = z.float()
        z_e = self.in_project(z).float()

        # Flatten batch and time for the nearest-neighbour search
        # (equivalent to einops rearrange "b t d -> (b t) d").
        encodings = z_e.reshape(-1, z_e.shape[-1]).float()  # (B*T, D')

        # Lazy k-means initialization on the first batch.
        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)

        # Squared Euclidean distance to every codebook row. Renamed from
        # `dist`, which shadowed the `torch.distributed as dist` alias.
        sq_dists = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ self.codebook.float().t()
            + self.codebook.float().pow(2).sum(1, keepdim=True).t()
        )  # (B*T, K)
        indices = (-sq_dists).max(1)[1]  # argmin distance, (B*T,)

        indices = indices.view(z.size(0), -1)  # (B, T)
        z_q = self.decode_code(indices).float()
        commit_loss = F.mse_loss(z_e, z_q.detach()) * self.commitment

        # Codebook maintenance only while actually training with grads enabled.
        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)

        # Straight-through estimator: identity in forward, grads flow to z_e.
        z_q = (z_q - z_e).detach() + z_e
        z_q = self.out_project(z_q).float()

        return (
            z_q,
            commit_loss,
            torch.tensor(0.0, device=z.device, dtype=torch.float32),
            indices,
            z_e,
        )

    def decode_code(self, embed_id):
        """Look up codebook vectors: embed_id (B, T) -> (B, T, D'), fp32."""
        return F.embedding(embed_id, self.codebook).float()
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# class VectorQuantize(nn.Module):
|
| 221 |
+
# """
|
| 222 |
+
# Implementation of VQ similar to Karpathy's repo:
|
| 223 |
+
# https://github.com/karpathy/deep-vector-quantization
|
| 224 |
+
# Additionally uses following tricks from Improved VQGAN
|
| 225 |
+
# (https://arxiv.org/pdf/2110.04627.pdf):
|
| 226 |
+
# 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
| 227 |
+
# for improved codebook usage
|
| 228 |
+
# 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
| 229 |
+
# improves training stability
|
| 230 |
+
# """
|
| 231 |
+
|
| 232 |
+
# def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
|
| 233 |
+
# super().__init__()
|
| 234 |
+
# self.codebook_size = codebook_size
|
| 235 |
+
# self.codebook_dim = codebook_dim
|
| 236 |
+
|
| 237 |
+
# self.in_proj = nn.Linear(input_dim, codebook_dim)
|
| 238 |
+
# self.out_proj = nn.Linear(codebook_dim, input_dim)
|
| 239 |
+
# self.codebook = nn.Embedding(codebook_size, codebook_dim)
|
| 240 |
+
|
| 241 |
+
# def forward(self, z: torch.Tensor):
|
| 242 |
+
# """
|
| 243 |
+
# Args:
|
| 244 |
+
# z (torch.Tensor): shape (b, t, d)
|
| 245 |
+
|
| 246 |
+
# Returns:
|
| 247 |
+
# z_q (torch.Tensor): shape (b, t, d)
|
| 248 |
+
# commitment_loss (torch.Tensor): shape (1)
|
| 249 |
+
# codebook_loss (torch.Tensor): shape (1)
|
| 250 |
+
# indices (torch.Tensor): shape (b, t)
|
| 251 |
+
# z_e (torch.Tensor): shape (b, t, d)
|
| 252 |
+
# """
|
| 253 |
+
|
| 254 |
+
# # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
|
| 255 |
+
# z_e = self.in_proj(z)
|
| 256 |
+
# z_q, indices = self.decode_latents(z_e)
|
| 257 |
+
|
| 258 |
+
# commitment_loss = F.mse_loss(z_e, z_q.detach()) * 0.25
|
| 259 |
+
# codebook_loss = F.mse_loss(z_q, z_e.detach())
|
| 260 |
+
|
| 261 |
+
# z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
|
| 262 |
+
|
| 263 |
+
# z_q = self.out_proj(z_q)
|
| 264 |
+
|
| 265 |
+
# return z_q, commitment_loss, codebook_loss, indices, z_e
|
| 266 |
+
|
| 267 |
+
# def embed_code(self, embed_id):
|
| 268 |
+
# return F.embedding(embed_id, self.codebook.weight)
|
| 269 |
+
|
| 270 |
+
# def decode_code(self, embed_id):
|
| 271 |
+
# return self.embed_code(embed_id)
|
| 272 |
+
|
| 273 |
+
# def decode_latents(self, latents: torch.Tensor):
|
| 274 |
+
# codebook = self.codebook.weight
|
| 275 |
+
# encodings = rearrange(latents, "b t d -> (b t) d")
|
| 276 |
+
|
| 277 |
+
# cosine_similarity = F.cosine_similarity(encodings[None], codebook[:, None], dim=-1)
|
| 278 |
+
# indices = cosine_similarity.max(dim=0)[1]
|
| 279 |
+
# indices = rearrange(indices, "(b t) -> b t", b=latents.size(0))
|
| 280 |
+
|
| 281 |
+
# # encodings = F.normalize(encodings)
|
| 282 |
+
# # codebook = F.normalize(codebook)
|
| 283 |
+
# # dist = (
|
| 284 |
+
# # encodings.pow(2).sum(1, keepdim=True)
|
| 285 |
+
# # - 2 * encodings @ codebook.t()
|
| 286 |
+
# # + codebook.pow(2).sum(1, keepdim=True).t()
|
| 287 |
+
# # )
|
| 288 |
+
# # indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
|
| 289 |
+
|
| 290 |
+
# z_q = self.decode_code(indices)
|
| 291 |
+
# return z_q, indices
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class ResidualVectorQuantize(nn.Module):
|
| 295 |
+
def __init__(
|
| 296 |
+
self,
|
| 297 |
+
dim: int = 256,
|
| 298 |
+
n_codebooks: int = 4,
|
| 299 |
+
codebook_size: int = 512,
|
| 300 |
+
codebook_dim: Union[int, list] = 8,
|
| 301 |
+
quantizer_dropout: float = 0.25,
|
| 302 |
+
commitment: float = 0.25,
|
| 303 |
+
decay: float = 0.99,
|
| 304 |
+
epsilon: float = 1e-5,
|
| 305 |
+
threshold_ema_dead: int = 2,
|
| 306 |
+
kmeans_init: bool = True,
|
| 307 |
+
kmeans_iters: int = 10,
|
| 308 |
+
rotation_trick: bool = False,
|
| 309 |
+
):
|
| 310 |
+
super().__init__()
|
| 311 |
+
if isinstance(codebook_dim, int):
|
| 312 |
+
codebook_dim = [codebook_dim for _ in range(n_codebooks)]
|
| 313 |
+
|
| 314 |
+
self.n_codebooks = n_codebooks
|
| 315 |
+
self.codebook_dim = codebook_dim
|
| 316 |
+
self.codebook_size = codebook_size
|
| 317 |
+
|
| 318 |
+
self.quantizers = nn.ModuleList(
|
| 319 |
+
[
|
| 320 |
+
VectorQuantize(
|
| 321 |
+
input_dim=dim,
|
| 322 |
+
codebook_size=codebook_size,
|
| 323 |
+
codebook_dim=codebook_dim[i],
|
| 324 |
+
commitment=commitment,
|
| 325 |
+
decay=decay,
|
| 326 |
+
epsilon=epsilon,
|
| 327 |
+
threshold_ema_dead=threshold_ema_dead,
|
| 328 |
+
kmeans_init=kmeans_init,
|
| 329 |
+
kmeans_iters=kmeans_iters,
|
| 330 |
+
rotation_trick=rotation_trick,
|
| 331 |
+
)
|
| 332 |
+
for i in range(n_codebooks)
|
| 333 |
+
]
|
| 334 |
+
)
|
| 335 |
+
self.quantizer_dropout = quantizer_dropout
|
| 336 |
+
|
| 337 |
+
def forward(self, z, n_quantizers: int = None):
|
| 338 |
+
"""Quantized the input tensor using a fixed set of `n` codebooks and returns
|
| 339 |
+
the corresponding codebook vectors
|
| 340 |
+
Parameters
|
| 341 |
+
----------
|
| 342 |
+
z : Tensor[B x D x T]
|
| 343 |
+
n_quantizers : int, optional
|
| 344 |
+
No. of quantizers to use
|
| 345 |
+
(n_quantizers < self.n_codebooks ex: for quantizer dropout)
|
| 346 |
+
Note: if `self.quantizer_dropout` is True, this argument is ignored
|
| 347 |
+
when in training mode, and a random number of quantizers is used.
|
| 348 |
+
Returns
|
| 349 |
+
-------
|
| 350 |
+
dict
|
| 351 |
+
A dictionary with the following keys:
|
| 352 |
+
|
| 353 |
+
"z" : Tensor[B x D x T]
|
| 354 |
+
Quantized continuous representation of input
|
| 355 |
+
"codes" : Tensor[B x N x T]
|
| 356 |
+
Codebook indices for each codebook
|
| 357 |
+
(quantized discrete representation of input)
|
| 358 |
+
"latents" : Tensor[B x N*D x T]
|
| 359 |
+
Projected latents (continuous representation of input before quantization)
|
| 360 |
+
"vq/commitment_loss" : Tensor[1]
|
| 361 |
+
Commitment loss to train encoder to predict vectors closer to codebook
|
| 362 |
+
entries
|
| 363 |
+
"vq/codebook_loss" : Tensor[1]
|
| 364 |
+
Codebook loss to update the codebook
|
| 365 |
+
"""
|
| 366 |
+
z_q, residual = 0, z
|
| 367 |
+
commitment_loss, codebook_loss = 0, 0
|
| 368 |
+
|
| 369 |
+
codebook_indices, latents = [], []
|
| 370 |
+
|
| 371 |
+
if n_quantizers is None:
|
| 372 |
+
n_quantizers = self.n_codebooks
|
| 373 |
+
if self.training:
|
| 374 |
+
n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
|
| 375 |
+
dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
|
| 376 |
+
n_dropout = int(z.shape[0] * self.quantizer_dropout)
|
| 377 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 378 |
+
n_quantizers = n_quantizers.to(z.device)
|
| 379 |
+
|
| 380 |
+
for i, quantizer in enumerate(self.quantizers):
|
| 381 |
+
if self.training is False and i >= n_quantizers:
|
| 382 |
+
break
|
| 383 |
+
|
| 384 |
+
z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
|
| 385 |
+
|
| 386 |
+
# Create mask to apply quantizer dropout
|
| 387 |
+
mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
|
| 388 |
+
z_q = z_q + z_q_i * mask[:, None, None]
|
| 389 |
+
residual = residual - z_q_i
|
| 390 |
+
|
| 391 |
+
# Sum losses
|
| 392 |
+
commitment_loss += (commitment_loss_i * mask).mean()
|
| 393 |
+
codebook_loss += (codebook_loss_i * mask).mean()
|
| 394 |
+
|
| 395 |
+
codebook_indices.append(indices_i)
|
| 396 |
+
latents.append(z_e_i)
|
| 397 |
+
|
| 398 |
+
codes = torch.stack(codebook_indices, dim=-1)
|
| 399 |
+
latents = torch.cat(latents, dim=1)
|
| 400 |
+
|
| 401 |
+
return z_q, codes, latents, commitment_loss, codebook_loss
|
| 402 |
+
|
| 403 |
+
def from_codes(self, codes: torch.Tensor):
    """Reconstruct the continuous representation from quantized codes.

    Parameters
    ----------
    codes : Tensor[B x T x N]
        Quantized discrete representation of the input; ``codes[..., i]``
        holds the indices for codebook ``i``. (The original docstring said
        ``B x N x T``, but the ``codes[..., i]`` indexing puts the codebook
        axis last — confirm against callers.)

    Returns
    -------
    Tensor
        Quantized continuous representation: the sum of each codebook's
        decoded-and-projected latents.
    Tensor
        Concatenation of the per-codebook latents along ``dim=-1``.
        NOTE(review): ``from_latents`` concatenates along ``dim=1`` —
        confirm which axis is intended for this codec's layout.
    Tensor
        The input ``codes``, returned unchanged.
    """
    z_q = 0.0
    z_p = []
    n_codebooks = codes.shape[-1]
    for i in range(n_codebooks):
        # Look up the codebook vectors for quantizer i.
        z_p_i = self.quantizers[i].decode_code(codes[..., i])
        z_p.append(z_p_i)

        # Fix: project with `out_proj`, matching `from_latents` and the DAC
        # reference implementation. The original called a non-existent
        # `out_project`, which would raise AttributeError at runtime.
        z_q_i = self.quantizers[i].out_proj(z_p_i)
        z_q = z_q + z_q_i
    return z_q, torch.cat(z_p, dim=-1), codes
|
| 424 |
+
|
| 425 |
+
def from_latents(self, latents: torch.Tensor):
    """Quantize pre-projection latents back into the continuous space.

    Parameters
    ----------
    latents : Tensor[B x N x T]
        Continuous representation of the input after projection; each
        quantizer owns a contiguous slice of the channel axis.

    Returns
    -------
    Tensor[B x D x T]
        Quantized representation in the full projected space.
    Tensor[B x N x T]
        Quantized representation in the latent space.
    Tensor
        Per-codebook code indices, stacked along ``dim=1``.
    """
    quantized = 0
    projected = []
    code_list = []
    # Channel offsets of each quantizer's slice along the latent axis.
    offsets = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

    # Use only as many codebooks as the provided latents can fill.
    usable = int(np.where(offsets <= latents.shape[1])[0].max())
    for idx in range(usable):
        start, stop = offsets[idx], offsets[idx + 1]
        z_p_i, codes_i = self.quantizers[idx].decode_latents(latents[:, start:stop, :])
        projected.append(z_p_i)
        code_list.append(codes_i)
        # Accumulate each codebook's contribution in the output space.
        quantized = quantized + self.quantizers[idx].out_proj(z_p_i)

    return quantized, torch.cat(projected, dim=1), torch.stack(code_list, dim=1)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class IndependentVectorQuantize(nn.Module):
    """A bank of independent vector quantizers, one per codebook.

    The input to :meth:`forward` is expected as ``[B, num_codebooks, D]``;
    codebook slice ``i`` is quantized by its own underlying ``torchVQ``
    instance, and per-slice results are concatenated back together.
    """

    def __init__(self, num_codebooks: int = 1, **kwargs):
        super().__init__()
        # One independent quantizer per codebook; kwargs are forwarded verbatim.
        self.vector_quantizers = nn.ModuleList([torchVQ(**kwargs) for _ in range(num_codebooks)])
        self.num_codebooks = num_codebooks
        self.codebook_size = self.vector_quantizers[0].codebook_size

    @property
    def ema_update(self):
        # Per-codebook EMA-update values from the underlying quantizers.
        return [quantizer.ema_update for quantizer in self.vector_quantizers]

    @property
    def codebook(self):
        # Stacked view of all codebooks along a new leading axis.
        return torch.stack([quantizer.codebook for quantizer in self.vector_quantizers], dim=0)

    @codebook.setter
    def codebook(self, codes: List[torch.Tensor]):
        assert len(codes) == self.num_codebooks, "Number of codebooks must match"
        # NOTE(review): `separate_codebook_per_head` is never assigned on this
        # class, so reaching this line raises AttributeError — confirm whether
        # it should be a constructor argument or belongs to torchVQ.
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        for idx, code in enumerate(codes):
            self.vector_quantizers[idx].codebook.copy_(code)

    def get_codes_from_indices(self, indices: torch.Tensor):
        # Decode each codebook's indices independently, then rejoin slices.
        per_codebook = [
            self.vector_quantizers[i].get_codes_from_indices(indices[..., i : i + 1])
            for i in range(self.num_codebooks)
        ]
        return torch.cat(per_codebook, dim=-2)

    def get_output_from_indices(self, indices: torch.Tensor):
        # Same pattern as get_codes_from_indices, but for the output space.
        per_codebook = [
            self.vector_quantizers[i].get_output_from_indices(indices[..., i : i + 1])
            for i in range(self.num_codebooks)
        ]
        return torch.cat(per_codebook, dim=-2)

    def update_in_place_optimizer(self):
        # Propagate the in-place optimizer step to every quantizer.
        for quantizer in self.vector_quantizers:
            quantizer.update_in_place_optimizer()

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Quantize each codebook slice of ``x`` with its own quantizer.

        Returns the concatenated quantized tensor, the concatenated code
        indices, and the mean commitment loss across codebooks.
        """
        assert x.shape[1] == self.num_codebooks
        results = [self.vector_quantizers[i](x[:, i : i + 1]) for i in range(self.num_codebooks)]
        quantized = torch.cat([q for q, _, _ in results], dim=-2)
        indices = torch.cat([idx for _, idx, _ in results], dim=-1)
        total_commit = sum(loss for _, _, loss in results)
        return quantized, indices, total_commit / self.num_codebooks
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
if __name__ == "__main__":
    # Smoke test: quantize a [B, num_codebooks, dim] tensor with independent VQs.
    vq = IndependentVectorQuantize(
        num_codebooks=16,
        dim=256,
        codebook_size=2048,
        decay=0.8,  # the exponential moving average decay, lower means the dictionary will change faster
        commitment_weight=1.0,  # the weight on the commitment loss
    )

    x = torch.randn(1, 16, 256)
    # Expected per forward's cat dims: quantized (1, 16, 256), indices (1, 16),
    # scalar commit loss. The previous "(1, 1024, 256), (1, 1024), (1)" comment
    # looks stale (copied from a different configuration) — confirm.
    quantized, indices, commit_loss = vq(x)
|