File size: 28,287 Bytes
47e954c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
"""ACT: Action Chunking with Transformers implementation.

This module implements the ACT (Action Chunking with Transformers) model
from "Learning fine-grained bimanual manipulation with low-cost hardware"
(Zhao et al., 2023). ACT uses a transformer architecture with latent variable
modeling to predict action sequences for robot manipulation tasks.

Reference: Zhao, Tony Z., et al. "Learning fine-grained bimanual manipulation
with low-cost hardware." arXiv preprint arXiv:2304.13705 (2023).
"""

import logging
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from neuracore_types import (
    BatchedJointData,
    BatchedNCData,
    BatchedParallelGripperOpenAmountData,
    BatchedRGBData,
    CameraDataStats,
    DataItemStats,
    DataType,
    JointDataStats,
    ModelInitDescription,
    ParallelGripperOpenAmountDataStats,
)

from neuracore.ml import (
    BatchedInferenceInputs,
    BatchedTrainingOutputs,
    BatchedTrainingSamples,
    NeuracoreModel,
)
from neuracore.ml.algorithm_utils.normalizer import MeanStdNormalizer

from .modules import (
    ACTImageEncoder,
    PositionalEncoding,
    TransformerDecoder,
    TransformerEncoder,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Normalizer classes used for proprioceptive inputs and for action targets.
# They are separate names so that either one can be swapped independently.
PROPRIO_NORMALIZER = MeanStdNormalizer  # or MinMaxNormalizer
ACTION_NORMALIZER = MeanStdNormalizer  # or MinMaxNormalizer
# Per-channel (3 values) mean/std handed to T.Normalize when the model is built
# with use_resnet_stats=True.
# NOTE(review): presumably the usual ImageNet statistics — confirm.
RESNET_MEAN = [0.485, 0.456, 0.406]
RESNET_STD = [0.229, 0.224, 0.225]


class ACT(NeuracoreModel):
    """Implementation of ACT (Action Chunking Transformer) model.

    ACT is a transformer-based architecture that learns to predict sequences
    of robot actions by encoding visual observations and proprioceptive state
    into a transformer memory, then decoding a whole action chunk in a single
    decoder pass from learned query embeddings (not step by step).

    The model uses a variational autoencoder framework with separate encoders
    for visual features and action sequences, combined with a transformer
    decoder for action generation.
    """

    def __init__(
        self,
        model_init_description: ModelInitDescription,
        hidden_dim: int = 512,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 1,
        nheads: int = 8,
        dim_feedforward: int = 3200,
        dropout: float = 0.1,
        use_resnet_stats: bool = True,
        lr: float = 1e-4,
        freeze_backbone: bool = False,
        lr_backbone: float = 1e-5,
        weight_decay: float = 1e-4,
        kl_weight: float = 10.0,
        latent_dim: int = 512,
    ):
        """Initialize the ACT model.

        Builds, in this order: the proprioceptive state embedding, the action
        embedding, the input/output normalizers, one image encoder per camera
        (only when RGB images are an input), the main transformer
        encoder/decoder, the separate latent encoder with its projections, the
        learned decoder queries, and finally the optimizer parameter groups.

        Args:
            model_init_description: Model initialization parameters
            hidden_dim: Hidden dimension for transformer layers
            num_encoder_layers: Number of transformer encoder layers (used for
                both the main encoder and the latent encoder)
            num_decoder_layers: Number of transformer decoder layers
            nheads: Number of attention heads
            dim_feedforward: Feedforward network dimension
            dropout: Dropout probability
            use_resnet_stats: If True, images are normalized with RESNET_MEAN /
                RESNET_STD; otherwise with the per-channel mean of the dataset
                frame statistics
            lr: Learning rate for main parameters
            freeze_backbone: Whether to freeze image encoder backbone
            lr_backbone: Learning rate for image encoder backbone
            weight_decay: Weight decay for optimizer
            kl_weight: Weight for KL divergence loss
            latent_dim: Dimension of latent variable space

        Raises:
            ValueError: If an output data type is neither a joint-position nor
                a parallel-gripper open-amount type.
        """
        super().__init__(model_init_description)
        self.hidden_dim = hidden_dim
        self.use_resnet_stats = use_resnet_stats
        self.freeze_backbone = freeze_backbone
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
        self.kl_weight = kl_weight
        self.latent_dim = latent_dim

        gripper_types = (
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
            DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
        )
        joint_output_types = (
            DataType.JOINT_TARGET_POSITIONS,
            DataType.JOINT_POSITIONS,
        )

        def merge_stats(data_type: DataType) -> DataItemStats:
            """Concatenate every per-item statistic of ``data_type`` into one."""
            if data_type in gripper_types:
                gripper_stats = cast(
                    list[ParallelGripperOpenAmountDataStats],
                    self.dataset_statistics[data_type],
                )
                per_item = [stat.open_amount for stat in gripper_stats]
            else:
                joint_stats = cast(
                    list[JointDataStats], self.dataset_statistics[data_type]
                )
                per_item = [stat.value for stat in joint_stats]

            merged = DataItemStats()
            for item_stats in per_item:
                merged = merged.concatenate(item_stats)
            return merged

        # Setup proprioceptive data. proprio_dims records, per input type, the
        # [start, end) slice it occupies in the concatenated state vector.
        self.proprio_dims: dict[DataType, tuple[int, int]] = {}
        proprio_stats = []
        current_dim = 0

        for data_type in (
            DataType.JOINT_POSITIONS,
            DataType.JOINT_VELOCITIES,
            DataType.JOINT_TORQUES,
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
        ):
            # Only types that are model *inputs* contribute to the state
            # vector; output-only types are handled by the output loop below.
            if data_type not in self.data_types:
                continue
            if data_type not in self.input_data_types:
                continue

            combined_stats = merge_stats(data_type)
            proprio_stats.append(combined_stats)
            dim = len(combined_stats.mean)
            self.proprio_dims[data_type] = (current_dim, current_dim + dim)
            current_dim += dim

        # State embedding (None when the model has no proprioceptive input)
        state_input_dim = current_dim
        self.state_embed = (
            nn.Linear(state_input_dim, hidden_dim) if state_input_dim > 0 else None
        )

        # Setup output data
        self.max_output_size = 0
        output_stats = []
        for data_type in self.output_data_types:
            if data_type not in joint_output_types and data_type not in gripper_types:
                raise ValueError(f"Unsupported output data type: {data_type}")

            combined_stats = merge_stats(data_type)
            output_stats.append(combined_stats)
            self.max_output_size += len(combined_stats.mean)

        # Action embedding
        self.action_embed = nn.Linear(self.max_output_size, hidden_dim)

        # Setup normalizers
        # Only create proprio_normalizer if there are proprioception stats
        # This allows the algorithm to work without proprioception (visual-only)
        self.proprio_normalizer = (
            PROPRIO_NORMALIZER(name="proprioception", statistics=proprio_stats)
            if proprio_stats
            else None
        )
        self.action_normalizer = ACTION_NORMALIZER(
            name="actions", statistics=output_stats
        )

        # Vision components: one {transform, encoder} pair per camera
        if DataType.RGB_IMAGES in self.input_data_types:
            camera_stats_list = cast(
                list[CameraDataStats], self.dataset_statistics[DataType.RGB_IMAGES]
            )
            self.image_encoders = nn.ModuleList()
            for camera_stats in camera_stats_list:
                if use_resnet_stats:
                    mean, std = RESNET_MEAN, RESNET_STD
                else:
                    # Reduce the (C, H, W) dataset statistics to one value per
                    # channel, which is what T.Normalize expects.
                    mean_c_h_w = camera_stats.frame.mean
                    std_c_h_w = camera_stats.frame.std
                    mean = mean_c_h_w.mean(axis=(1, 2)).tolist()
                    std = std_c_h_w.mean(axis=(1, 2)).tolist()

                encoder = nn.ModuleDict({
                    "transform": torch.nn.Sequential(
                        T.Resize((224, 224)),
                        T.Normalize(mean=mean, std=std),
                    ),
                    "encoder": ACTImageEncoder(output_dim=hidden_dim),
                })
                self.image_encoders.append(encoder)

        # CLS token embedding for latent encoder
        self.cls_embed = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # Main transformer for vision and action generation
        self.transformer = nn.ModuleDict({
            "encoder": TransformerEncoder(
                d_model=hidden_dim,
                nhead=nheads,
                num_encoder_layers=num_encoder_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
            ),
            "decoder": TransformerDecoder(
                d_model=hidden_dim,
                nhead=nheads,
                num_decoder_layers=num_decoder_layers,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
            ),
        })

        # Separate encoder for latent space
        self.latent_encoder = TransformerEncoder(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(hidden_dim, dropout)

        # Output heads
        self.action_head = nn.Linear(hidden_dim, self.max_output_size)

        # Latent projections
        self.latent_mu = nn.Linear(hidden_dim, latent_dim)
        self.latent_logvar = nn.Linear(hidden_dim, latent_dim)
        self.latent_out_proj = nn.Linear(latent_dim, hidden_dim)

        # Query embedding for decoding: one learned query per predicted step
        self.query_embed = nn.Parameter(
            torch.randn(self.output_prediction_horizon, 1, hidden_dim)
        )

        # Additional position embeddings for proprio and latent
        self.additional_pos_embed = nn.Parameter(torch.randn(2, 1, hidden_dim))

        # Setup parameter groups
        self._setup_optimizer_param_groups()

    def _setup_optimizer_param_groups(self) -> None:
        """Setup parameter groups for optimizer."""
        backbone_params, other_params = [], []
        for name, param in self.named_parameters():
            if any(backbone in name for backbone in ["image_encoders"]):
                backbone_params.append(param)
            else:
                other_params.append(param)

        if self.freeze_backbone:
            for param in backbone_params:
                param.requires_grad = False
            self.param_groups = [{"params": other_params, "lr": self.lr}]
        else:
            self.param_groups = [
                {"params": backbone_params, "lr": self.lr_backbone},
                {"params": other_params, "lr": self.lr},
            ]

    def _reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample from latent distribution using reparametrization trick.

        During training, samples from the distribution N(mu, exp(logvar)).
        During inference, returns the mean mu.

        Args:
            mu: Mean of latent distribution
            logvar: Log variance of latent distribution

        Returns:
            torch.Tensor: Sampled latent variable
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def _combine_proprio(self, batch: BatchedInferenceInputs) -> torch.FloatTensor:
        """Build the normalized proprioceptive state vector for a batch.

        For every proprioceptive data type present in ``batch.inputs`` (joint
        positions, velocities, torques, parallel gripper open amounts — in
        that order) the per-item tensors are concatenated on the last axis,
        the last time step is taken and multiplied by the matching input
        mask. The pieces are then concatenated and normalized in one go.

        Args:
            batch: Input batch containing joint state data

        Returns:
            torch.FloatTensor: Combined and normalized joint state features,
            or ``None`` when the model has no state embedding or the batch
            carries no proprioceptive input (visual-only operation).

        Raises:
            ValueError: If proprioceptive inputs are present but the model has
                no proprioception normalizer.
        """
        if self.state_embed is None:
            return None

        proprio_order = (
            DataType.JOINT_POSITIONS,
            DataType.JOINT_VELOCITIES,
            DataType.JOINT_TORQUES,
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
        )

        pieces = []
        for data_type in proprio_order:
            if data_type in batch.inputs:
                items = batch.inputs[data_type]
                mask = batch.inputs_mask[data_type]

                if data_type == DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS:
                    grippers = cast(list[BatchedParallelGripperOpenAmountData], items)
                    tensors = [gripper.open_amount for gripper in grippers]
                else:
                    joints = cast(list[BatchedJointData], items)
                    tensors = [joint.value for joint in joints]

                stacked = torch.cat(tensors, dim=-1)
                # Only the most recent time step is used: (B, num_features)
                pieces.append(stacked[:, -1, :] * mask)

        # Nothing proprioceptive in this batch -> visual-only path
        if len(pieces) == 0:
            return None

        # (B, total_proprio_dim)
        combined = torch.cat(pieces, dim=-1)

        if self.proprio_normalizer is None:
            raise ValueError(
                "Proprioception inputs were provided but no normalizer was available."
            )
        return self.proprio_normalizer.normalize(combined)

    def _encode_latent(
        self,
        state: torch.FloatTensor,
        actions: torch.FloatTensor,
        actions_mask: torch.FloatTensor,
        actions_sequence_mask: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Encode actions to latent space during training.

        Uses a separate transformer encoder to encode the action sequence
        along with proprioceptive state into latent distribution parameters.

        Args:
            state: Proprioceptive state features
            actions: Target action sequence
            actions_mask: Mask for valid action dimensions
            actions_sequence_mask: Mask for valid sequence positions

        Returns:
            tuple[torch.FloatTensor, torch.FloatTensor]: Latent mean and log variance
        """
        batch_size = state.shape[0]

        # Project joint positions and actions
        state_embed = (
            self.state_embed(state) if self.state_embed is not None else None
        )  # [B, H]
        action_embed = self.action_embed(
            actions * actions_mask.unsqueeze(1)
        )  # [B, T, H]

        # Reshape to sequence first
        state_embed = (
            state_embed.unsqueeze(0) if state_embed is not None else None
        )  # [1, B, H]
        action_embed = action_embed.transpose(0, 1)  # [T, B, H]

        # Concatenate [CLS, state_emb, action_embed]
        cls_token = self.cls_embed.expand(-1, batch_size, -1)  # [1, B, H]
        encoder_input = torch.cat([cls_token, state_embed, action_embed], dim=0)

        # # Update padding mask
        if actions_sequence_mask is not None:
            cls_joint_pad = torch.zeros(
                batch_size, 2, dtype=torch.bool, device=self.device
            )
            actions_sequence_mask = torch.cat(
                [cls_joint_pad, actions_sequence_mask], dim=1
            )

        # Add positional encoding
        encoder_input = self.pos_encoder(encoder_input)

        # Encode sequence
        memory = self.latent_encoder(
            encoder_input, src_key_padding_mask=actions_sequence_mask
        )

        # Get latent parameters from CLS token
        mu = self.latent_mu(memory[0])  # Take CLS token output
        logvar = self.latent_logvar(memory[0])
        return mu, logvar

    def _encode_visual(
        self,
        states: torch.FloatTensor,
        batched_nc_data: list[BatchedNCData],
        camera_images_mask: torch.FloatTensor,
        latent: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Encode visual inputs with latent and proprioceptive features.

        Processes the last frame of every camera through its image encoder,
        zeroes the features of masked-out cameras, and feeds the resulting
        token sequence — prefixed with a latent token and a state token — to
        the main transformer encoder.

        Args:
            states: Proprioceptive state features, or None when no
                proprioception is available (a zero token is used instead)
            batched_nc_data: List of BatchedRGBData
            camera_images_mask: Mask for valid camera inputs, indexed as
                ``[:, cam_id]``
            latent: Latent features from action encoding [B, H]

        Returns:
            torch.FloatTensor: Encoded visual and proprioceptive memory
        """
        batched_rgb_data = cast(list[BatchedRGBData], batched_nc_data)
        # ``states`` may be None in visual-only setups, so the batch size is
        # taken from ``latent``, which is always provided.
        batch_size = latent.shape[0]

        # Process images
        image_features = []
        image_pos = []
        for cam_id, (encoder_dict, input_rgb) in enumerate(
            zip(self.image_encoders, batched_rgb_data)
        ):
            last_frame = input_rgb.frame[:, -1, :, :, :]  # (B, 3, H, W)
            transformed = encoder_dict["transform"](last_frame)
            features, pos = encoder_dict["encoder"](
                transformed
            )  # Vision backbone provides features and pos
            # Out-of-place multiply: do not modify the encoder output in place.
            camera_mask = camera_images_mask[:, cam_id].view(batch_size, 1, 1, 1)
            features = features * camera_mask
            image_features.append(features)
            image_pos.append(pos)

        # Combine image features and positions by concatenating the cameras
        # along the last (width) axis
        combined_features = torch.cat(image_features, dim=3)  # [B, C, H, W]
        combined_pos = torch.cat(image_pos, dim=3)  # [B, C, H, W]

        # Convert to sequence [H*W, B, C]
        src = combined_features.flatten(2).permute(2, 0, 1)
        pos = combined_pos.flatten(2).permute(2, 0, 1)

        # State token: embedded proprioception, or zeros when unavailable
        if states is not None and self.state_embed is not None:
            state_features = self.state_embed(states)  # [B, H]
        else:
            state_features = torch.zeros_like(latent)  # [B, H]

        # Stack latent and proprio features
        additional_features = torch.stack([latent, state_features], dim=0)  # [2, B, H]

        # Add position embeddings from additional_pos_embed
        additional_pos = self.additional_pos_embed.expand(
            -1, batch_size, -1
        )  # [2, B, H]

        # Concatenate everything
        src = torch.cat([additional_features, src], dim=0)
        pos = torch.cat([additional_pos, pos], dim=0)

        # Fuse positional embeddings with source
        src = src + pos

        # Encode
        memory = self.transformer["encoder"](src)

        return memory

    def _decode(
        self,
        latent: torch.FloatTensor,
        memory: torch.FloatTensor,
    ) -> torch.Tensor:
        """Decode latent and visual features to action sequence.

        Uses a transformer decoder with learned query embeddings to generate
        a sequence of action predictions conditioned on visual and latent features.

        Args:
            latent: Latent features
            memory: Encoded visual and proprioceptive memory

        Returns:
            torch.Tensor: Predicted action sequence [B, T, action_dim]
        """
        batch_size = latent.shape[0]

        # Convert to sequence first and expand
        query_embed = self.query_embed.expand(-1, batch_size, -1)  # [T, B, H]
        latent = latent.unsqueeze(0).expand_as(query_embed)  # [T, B, H]

        # Add latent to query embedding
        query_embed = query_embed + latent

        # Initialize target with zeros
        tgt = torch.zeros_like(query_embed)

        # Decode sequence
        hs = self.transformer["decoder"](tgt, memory, query_pos=query_embed)

        # Project to action space (keeping sequence first)
        actions = self.action_head(hs)  # [T, B, A]

        # Convert back to batch first
        actions = actions.transpose(0, 1)  # [B, T, A]

        return actions

    def _predict_action(
        self,
        mu: torch.FloatTensor,
        logvar: torch.FloatTensor,
        batch: BatchedInferenceInputs,
    ) -> torch.FloatTensor:
        """Run the observation encoder and action decoder for one batch.

        A latent is drawn from ``(mu, logvar)``, projected to the hidden size,
        combined with the image and proprioceptive inputs in the encoder, and
        decoded into an action sequence.

        Args:
            mu: Mean of latent distribution
            logvar: Log variance of latent distribution
            batch: Input observations

        Returns:
            torch.FloatTensor: Predicted action sequence

        Raises:
            ValueError: If the batch contains no RGB images.
        """
        # Sample, then project the latent to the hidden size: [B, H]
        latent = self.latent_out_proj(self._reparametrize(mu, logvar))

        rgb_key = DataType.RGB_IMAGES
        if rgb_key not in batch.inputs:
            raise ValueError("No RGB images in batch")

        # Encoder memory over latent, proprioception and image tokens
        memory = self._encode_visual(
            self._combine_proprio(batch),
            batch.inputs[rgb_key],
            batch.inputs_mask[rgb_key],
            latent,
        )

        return self._decode(latent, memory)

    def forward(
        self, batch: BatchedInferenceInputs
    ) -> dict[DataType, list[BatchedNCData]]:
        """Predict an action sequence for a batch of observations.

        The latent distribution parameters are set to zero (mean and log
        variance), actions are predicted and un-normalized, and the action
        dimension is then split back into one ``(B, T, 1)`` tensor per item
        of every output data type.

        Args:
            batch: Input batch with observations

        Returns:
            dict[DataType, list[BatchedNCData]]: Model predictions with action sequences

        Raises:
            ValueError: If an output data type is not supported.
        """
        prior_shape = (len(batch), self.latent_dim)
        mu = torch.zeros(*prior_shape, device=self.device)
        logvar = torch.zeros(*prior_shape, device=self.device)

        # (B, T, action_dim), back in the original action scale
        predictions = self.action_normalizer.unnormalize(
            self._predict_action(mu, logvar, batch)
        )

        joint_types = (DataType.JOINT_TARGET_POSITIONS, DataType.JOINT_POSITIONS)
        gripper_types = (
            DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
        )

        output_tensors: dict[DataType, list[BatchedNCData]] = {}
        offset = 0
        for data_type in self.output_data_types:
            num_items = len(self.dataset_statistics[data_type])
            next_offset = offset + num_items
            dt_preds = predictions[:, :, offset:next_offset]  # (B, T, dt_size)

            # One (B, T, 1) slice per item of this data type
            per_item = [dt_preds[:, :, i : i + 1] for i in range(num_items)]

            if data_type in joint_types:
                output_tensors[data_type] = [
                    BatchedJointData(value=item_preds) for item_preds in per_item
                ]
            elif data_type in gripper_types:
                output_tensors[data_type] = [
                    BatchedParallelGripperOpenAmountData(open_amount=item_preds)
                    for item_preds in per_item
                ]
            else:
                raise ValueError(f"Unsupported output data type: {data_type}")

            offset = next_offset

        return output_tensors

    def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs:
        """Perform a single training step.

        Encodes action sequences to latent space, predicts actions, and computes
        L1 reconstruction loss plus KL divergence regularization.

        Args:
            batch: Training batch with inputs and targets

        Returns:
            BatchedTrainingOutputs: Training outputs with the combined loss
            under ``"l1_and_kl_loss"`` and the individual ``"l1_loss"`` /
            ``"kl_loss"`` terms as metrics

        Raises:
            ValueError: If an output data type is not supported.
        """
        inference_sample = BatchedInferenceInputs(
            inputs=batch.inputs,
            inputs_mask=batch.inputs_mask,
            batch_size=batch.batch_size,
        )

        # Extract target actions, one tensor per joint / gripper item
        action_targets = []
        for data_type in self.output_data_types:
            if data_type in [DataType.JOINT_TARGET_POSITIONS, DataType.JOINT_POSITIONS]:
                batched_joints = cast(list[BatchedJointData], batch.outputs[data_type])
                action_targets.extend([bjd.value for bjd in batched_joints])
            elif data_type in [
                DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
                DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
            ]:
                grippers = cast(
                    list[BatchedParallelGripperOpenAmountData], batch.outputs[data_type]
                )
                action_targets.extend([gripper.open_amount for gripper in grippers])
            else:
                raise ValueError(f"Unsupported output data type: {data_type}")

        action_data = torch.cat(action_targets, dim=-1)  # (B, T, action_dim)

        # Get masks (1 = valid)
        pred_sequence_mask = torch.ones_like(
            action_data[:, :, 0]
        )  # All time steps valid
        max_action_mask = torch.ones(
            batch.batch_size, self.max_output_size, device=self.device
        )  # All actions valid

        proprio_state = self._combine_proprio(inference_sample)

        # Normalize the targets once; the same tensor is both the latent
        # encoder input and the regression target for the L1 loss.
        target_actions = self.action_normalizer.normalize(action_data)

        mu, logvar = self._encode_latent(
            proprio_state,
            target_actions,
            max_action_mask,
            pred_sequence_mask,
        )

        action_preds = self._predict_action(mu, logvar, inference_sample)

        l1_loss_all = F.l1_loss(action_preds, target_actions, reduction="none")
        l1_loss = (l1_loss_all * pred_sequence_mask.unsqueeze(-1)).mean()
        # KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements and
        # divided by the batch size
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
        loss = l1_loss + self.kl_weight * kl_loss

        losses = {
            "l1_and_kl_loss": loss,
        }
        metrics = {
            "l1_loss": l1_loss,
            "kl_loss": kl_loss,
        }

        return BatchedTrainingOutputs(
            losses=losses,
            metrics=metrics,
        )

    def configure_optimizers(
        self,
    ) -> list[torch.optim.Optimizer]:
        """Configure optimizer with different learning rates.

        Uses separate learning rates for image encoder backbone and other
        model parameters.

        Returns:
            list[torch.optim.Optimizer]: List of optimizers for model parameters
        """
        return [torch.optim.AdamW(self.param_groups, weight_decay=self.weight_decay)]

    @staticmethod
    def get_supported_input_data_types() -> set[DataType]:
        """Return every data type this model accepts as an input.

        Returns:
            set[DataType]: Set of supported input data types
        """
        proprio_types = {
            DataType.JOINT_POSITIONS,
            DataType.JOINT_VELOCITIES,
            DataType.JOINT_TORQUES,
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
        }
        vision_types = {DataType.RGB_IMAGES}
        return proprio_types | vision_types

    @staticmethod
    def get_supported_output_data_types() -> set[DataType]:
        """Return every data type this model can produce as an output.

        Returns:
            set[DataType]: Set of supported output data types
        """
        measured_types = {
            DataType.JOINT_POSITIONS,
            DataType.PARALLEL_GRIPPER_OPEN_AMOUNTS,
        }
        target_types = {
            DataType.JOINT_TARGET_POSITIONS,
            DataType.PARALLEL_GRIPPER_TARGET_OPEN_AMOUNTS,
        }
        return measured_types | target_types