| """Neural network modules for the ACT (Action Chunking Transformer) model. |
| |
| This module provides the core building blocks for the ACT architecture including |
| image encoders, positional encodings, and transformer encoder/decoder layers. |
| These components are designed to work together for robot manipulation tasks. |
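
Example:
    A minimal composition sketch (illustrative only; the wiring and shapes
    below are assumptions for a 224x224 input, not the full ACT model):

        import torch

        image_encoder = ACTImageEncoder(output_dim=256)
        encoder = TransformerEncoder(d_model=256, nhead=8, num_encoder_layers=4)
        decoder = TransformerDecoder(d_model=256, nhead=8, num_decoder_layers=4)

        images = torch.randn(2, 3, 224, 224)             # (batch, C, H, W)
        feats, pos = image_encoder(images)               # each (2, 256, 7, 7)
        src = (feats + pos).flatten(2).permute(2, 0, 1)  # (H*W, batch, 256)
        memory = encoder(src)                            # (49, 2, 256)

        queries = torch.zeros(10, 2, 256)                # 10 action queries
        actions = decoder(queries, memory)               # (10, 2, 256)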
| """ |
|
|
| import math |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchvision.models as models |
|
|
|
|
| class ACTImageEncoder(nn.Module): |
| """Encode images using ResNet backbone with spatial position embeddings. |
| |
| Maintains spatial dimensions and provides learnable position embeddings |
| similar to DETR's backbone implementation. Uses a pretrained ResNet18 |
| backbone with projection layers for feature extraction. |
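
    Example:
        Illustrative usage (shapes assume a 224x224 RGB input and the default
        ``output_dim=256``; ResNet18 downsamples by a factor of 32):

            import torch

            encoder = ACTImageEncoder(output_dim=256)
            images = torch.randn(2, 3, 224, 224)
            features, pos = encoder(images)
            # features: (2, 256, 7, 7), pos: (2, 256, 7, 7)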
| """ |
|
|
| def __init__(self, output_dim: int = 256): |
| """Initialize the image encoder. |
| |
| Args: |
| output_dim: Output feature dimension after projection |
| """ |
| super().__init__() |
| |
| self.backbone = self._build_backbone() |
| self.proj = nn.Conv2d(512, output_dim, kernel_size=1) |
|
|
| |
| self.row_embed = nn.Embedding(50, output_dim // 2) |
| self.col_embed = nn.Embedding(50, output_dim // 2) |
| self.reset_parameters() |
|
|
| def _build_backbone(self) -> nn.Module: |
| """Build backbone CNN, removing avgpool and fc layers. |
| |
| Returns: |
| nn.Module: ResNet18 backbone without classification layers |
| """ |
| resnet = models.get_model("resnet18", weights="DEFAULT") |
| return nn.Sequential(*list(resnet.children())[:-2]) |
|
|
| def reset_parameters(self) -> None: |
| """Initialize position embeddings with uniform distribution.""" |
| nn.init.uniform_(self.row_embed.weight) |
| nn.init.uniform_(self.col_embed.weight) |
|
|
| def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
| """Forward pass through image encoder. |
| |
| Args: |
| x: Image tensor of shape (batch, channels, height, width) |
| |
| Returns: |
| tuple[torch.Tensor, torch.Tensor]: |
| - features: Encoded features of shape (batch, output_dim, height, width) |
| - pos: Position embeddings of shape (batch, output_dim, height, width) |
| """ |
| |
| x = self.backbone(x) |
| features = self.proj(x) |
|
|
| |
| h, w = features.shape[-2:] |
| i = torch.arange(w, device=x.device) |
| j = torch.arange(h, device=x.device) |
| x_emb = self.col_embed(i) |
| y_emb = self.row_embed(j) |
|
|
| pos = ( |
| torch.cat( |
| [ |
| x_emb.unsqueeze(0).repeat(h, 1, 1), |
| y_emb.unsqueeze(1).repeat(1, w, 1), |
| ], |
| dim=-1, |
| ) |
| .permute(2, 0, 1) |
| .unsqueeze(0) |
| ) |
|
|
| pos = pos.repeat(x.shape[0], 1, 1, 1) |
|
|
| return features, pos |
|
|
|
|
| class PositionalEncoding(nn.Module): |
| """Sinusoidal positional encoding for transformer sequences. |
| |
| Implements the standard sinusoidal positional encoding as described |
| in "Attention Is All You Need" (Vaswani et al., 2017). |
| """ |
|
|
| def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): |
| """Initialize positional encoding. |
| |
| Args: |
| d_model: Model dimension |
| dropout: Dropout probability |
| max_len: Maximum sequence length to support |
| """ |
| super().__init__() |
| self.dropout = nn.Dropout(p=dropout) |
|
|
| pe = torch.zeros(max_len, d_model) |
| position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
| div_term = torch.exp( |
| torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) |
| ) |
| pe[:, 0::2] = torch.sin(position * div_term) |
| pe[:, 1::2] = torch.cos(position * div_term) |
| pe = pe.unsqueeze(1) |
| self.register_buffer("pe", pe) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """Add positional encoding to input tensor. |
| |
| Args: |
| x: Input tensor of shape (seq_len, batch, d_model) |
| |
| Returns: |
| torch.Tensor: Input with positional encoding added |
| """ |
| x = x + self.pe[: x.size(0), :, :] |
| return self.dropout(x) |
|
|
|
|
| class TransformerEncoderLayer(nn.Module): |
| """Single transformer encoder layer with pre-layer normalization. |
| |
| Implements a transformer encoder layer following the pre-norm architecture |
| with multi-head self-attention and position-wise feed-forward network. |
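
    Example:
        Illustrative usage (shapes are assumptions):

            import torch

            layer = TransformerEncoderLayer(d_model=256, nhead=8)
            src = torch.randn(49, 2, 256)   # (seq_len, batch, d_model)
            out = layer(src)                # (49, 2, 256)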
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| dim_feedforward: int = 2048, |
| dropout: float = 0.1, |
| ): |
| """Initialize transformer encoder layer. |
| |
| Args: |
| d_model: Model dimension |
| nhead: Number of attention heads |
| dim_feedforward: Feedforward network hidden dimension |
| dropout: Dropout probability |
| """ |
| super().__init__() |
| self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| self.linear1 = nn.Linear(d_model, dim_feedforward) |
| self.linear2 = nn.Linear(dim_feedforward, d_model) |
| self.norm1 = nn.LayerNorm(d_model) |
| self.norm2 = nn.LayerNorm(d_model) |
| self.dropout = nn.Dropout(dropout) |
| self.dropout1 = nn.Dropout(dropout) |
| self.dropout2 = nn.Dropout(dropout) |
|
|
| def forward( |
| self, |
| src: torch.Tensor, |
| src_mask: torch.Tensor | None = None, |
| src_key_padding_mask: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Forward pass through encoder layer. |
| |
| Args: |
| src: Input tensor of shape (seq_len, batch, d_model) |
| src_mask: Optional attention mask |
| src_key_padding_mask: Optional key padding mask |
| |
| Returns: |
| torch.Tensor: Output tensor of same shape as input |
| """ |
| src2 = self.norm1(src) |
| src2, _ = self.self_attn( |
| src2, src2, src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask |
| ) |
| src = src + self.dropout1(src2) |
|
|
| src2 = self.norm2(src) |
| src2 = self.linear2(self.dropout(F.relu(self.linear1(src2)))) |
| src = src + self.dropout2(src2) |
|
|
| return src |
|
|
|
|
| class TransformerDecoderLayer(nn.Module): |
| """Single transformer decoder layer with cross-attention. |
| |
| Implements a transformer decoder layer with masked self-attention, |
| cross-attention to encoder memory, and position-wise feed-forward network. |
| Supports query position embeddings for object detection style architectures. |
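
    Example:
        Illustrative usage (shapes are assumptions):

            import torch

            layer = TransformerDecoderLayer(d_model=256, nhead=8)
            tgt = torch.randn(10, 2, 256)        # (tgt_len, batch, d_model)
            memory = torch.randn(49, 2, 256)     # (src_len, batch, d_model)
            query_pos = torch.randn(10, 2, 256)  # learned query embeddings
            out = layer(tgt, memory, query_pos=query_pos)   # (10, 2, 256)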
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| dim_feedforward: int = 2048, |
| dropout: float = 0.1, |
| ): |
| """Initialize transformer decoder layer. |
| |
| Args: |
| d_model: Model dimension |
| nhead: Number of attention heads |
| dim_feedforward: Feedforward network hidden dimension |
| dropout: Dropout probability |
| """ |
| super().__init__() |
| self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| self.linear1 = nn.Linear(d_model, dim_feedforward) |
| self.linear2 = nn.Linear(dim_feedforward, d_model) |
| self.norm1 = nn.LayerNorm(d_model) |
| self.norm2 = nn.LayerNorm(d_model) |
| self.norm3 = nn.LayerNorm(d_model) |
| self.dropout = nn.Dropout(dropout) |
| self.dropout1 = nn.Dropout(dropout) |
| self.dropout2 = nn.Dropout(dropout) |
| self.dropout3 = nn.Dropout(dropout) |
|
|
| def forward( |
| self, |
| tgt: torch.Tensor, |
| memory: torch.Tensor, |
| tgt_mask: torch.Tensor | None = None, |
| memory_mask: torch.Tensor | None = None, |
| tgt_key_padding_mask: torch.Tensor | None = None, |
| memory_key_padding_mask: torch.Tensor | None = None, |
| query_pos: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Forward pass through decoder layer. |
| |
| Args: |
| tgt: Target tensor of shape (tgt_len, batch, d_model) |
| memory: Memory tensor from encoder of shape (src_len, batch, d_model) |
| tgt_mask: Optional target attention mask |
| memory_mask: Optional memory attention mask |
| tgt_key_padding_mask: Optional target key padding mask |
| memory_key_padding_mask: Optional memory key padding mask |
| query_pos: Optional query position embeddings |
| |
| Returns: |
| torch.Tensor: Output tensor of same shape as target |
| """ |
| q = k = tgt if query_pos is None else tgt + query_pos |
| tgt2 = self.norm1(tgt) |
| tgt2, _ = self.self_attn( |
| q, k, tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask |
| ) |
| tgt = tgt + self.dropout1(tgt2) |
|
|
| tgt2 = self.norm2(tgt) |
| tgt2, _ = self.multihead_attn( |
| query=tgt2 if query_pos is None else tgt2 + query_pos, |
| key=memory, |
| value=memory, |
| attn_mask=memory_mask, |
| key_padding_mask=memory_key_padding_mask, |
| ) |
| tgt = tgt + self.dropout2(tgt2) |
|
|
| tgt2 = self.norm3(tgt) |
| tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt2)))) |
| tgt = tgt + self.dropout3(tgt2) |
|
|
| return tgt |
|
|
|
|
| class TransformerEncoder(nn.Module): |
| """Stack of transformer encoder layers. |
| |
| Composes multiple TransformerEncoderLayer modules with a final layer |
| normalization for encoding sequential input data. |
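
    Example:
        Illustrative usage (shapes are assumptions):

            import torch

            encoder = TransformerEncoder(d_model=256, nhead=8, num_encoder_layers=4)
            src = torch.randn(49, 2, 256)   # (seq_len, batch, d_model)
            memory = encoder(src)           # (49, 2, 256)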
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| num_encoder_layers: int = 6, |
| dim_feedforward: int = 2048, |
| dropout: float = 0.1, |
| ): |
| """Initialize transformer encoder. |
| |
| Args: |
| d_model: Model dimension |
| nhead: Number of attention heads |
| num_encoder_layers: Number of encoder layers |
| dim_feedforward: Feedforward network hidden dimension |
| dropout: Dropout probability |
| """ |
| super().__init__() |
| self.layers = nn.ModuleList([ |
| TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout) |
| for _ in range(num_encoder_layers) |
| ]) |
| self.norm = nn.LayerNorm(d_model) |
|
|
| def forward( |
| self, |
| src: torch.Tensor, |
| mask: torch.Tensor | None = None, |
| src_key_padding_mask: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Forward pass through all encoder layers. |
| |
| Args: |
| src: Input tensor of shape (seq_len, batch, d_model) |
| mask: Optional attention mask |
| src_key_padding_mask: Optional key padding mask |
| |
| Returns: |
| torch.Tensor: Encoded output tensor |
| """ |
| output = src |
|
|
| for layer in self.layers: |
| output = layer( |
| output, src_mask=mask, src_key_padding_mask=src_key_padding_mask |
| ) |
|
|
| return self.norm(output) |
|
|
|
|
| class TransformerDecoder(nn.Module): |
| """Stack of transformer decoder layers. |
| |
| Composes multiple TransformerDecoderLayer modules with a final layer |
| normalization for generating sequential output conditioned on memory. |
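
    Example:
        Illustrative usage (shapes are assumptions):

            import torch

            decoder = TransformerDecoder(d_model=256, nhead=8, num_decoder_layers=4)
            tgt = torch.zeros(10, 2, 256)        # (tgt_len, batch, d_model)
            memory = torch.randn(49, 2, 256)     # encoder output
            query_pos = torch.randn(10, 2, 256)  # learned query embeddings
            out = decoder(tgt, memory, query_pos=query_pos)   # (10, 2, 256)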
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| num_decoder_layers: int = 6, |
| dim_feedforward: int = 2048, |
| dropout: float = 0.1, |
| ): |
| """Initialize transformer decoder. |
| |
| Args: |
| d_model: Model dimension |
| nhead: Number of attention heads |
| num_decoder_layers: Number of decoder layers |
| dim_feedforward: Feedforward network hidden dimension |
| dropout: Dropout probability |
| """ |
| super().__init__() |
| self.layers = nn.ModuleList([ |
| TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout) |
| for _ in range(num_decoder_layers) |
| ]) |
| self.norm = nn.LayerNorm(d_model) |
|
|
| def forward( |
| self, |
| tgt: torch.Tensor, |
| memory: torch.Tensor, |
| tgt_mask: torch.Tensor | None = None, |
| memory_mask: torch.Tensor | None = None, |
| tgt_key_padding_mask: torch.Tensor | None = None, |
| memory_key_padding_mask: torch.Tensor | None = None, |
| query_pos: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Forward pass through all decoder layers. |
| |
| Args: |
| tgt: Target tensor of shape (tgt_len, batch, d_model) |
| memory: Memory tensor from encoder of shape (src_len, batch, d_model) |
| tgt_mask: Optional target attention mask |
| memory_mask: Optional memory attention mask |
| tgt_key_padding_mask: Optional target key padding mask |
| memory_key_padding_mask: Optional memory key padding mask |
| query_pos: Optional query position embeddings |
| |
| Returns: |
| torch.Tensor: Decoded output tensor |
| """ |
| output = tgt |
|
|
| for layer in self.layers: |
| output = layer( |
| output, |
| memory, |
| tgt_mask=tgt_mask, |
| memory_mask=memory_mask, |
| tgt_key_padding_mask=tgt_key_padding_mask, |
| memory_key_padding_mask=memory_key_padding_mask, |
| query_pos=query_pos, |
| ) |
|
|
| return self.norm(output) |
|
|