File size: 13,797 Bytes
47e954c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
"""Neural network modules for the ACT (Action Chunking Transformer) model.

This module provides the core building blocks for the ACT architecture including
image encoders, positional encodings, and transformer encoder/decoder layers.
These components are designed to work together for robot manipulation tasks.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class ACTImageEncoder(nn.Module):
    """Encode images with a ResNet18 backbone plus learned 2D position embeddings.

    The backbone keeps its spatial feature map (the trailing pooling and
    classification layers are dropped), a 1x1 convolution projects the 512
    backbone channels to ``output_dim``, and learnable row/column embedding
    tables provide a position map with the same shape as the features,
    similar to DETR's learned position embedding.
    """

    def __init__(self, output_dim: int = 256):
        """Initialize the image encoder.

        Args:
            output_dim: Output feature dimension after projection. Must be
                even, because half of the channels of the position map come
                from the column table and half from the row table.

        Raises:
            ValueError: If ``output_dim`` is odd.
        """
        super().__init__()
        # Fail fast (before loading backbone weights): an odd size would make
        # the position map one channel narrower than the projected features.
        if output_dim % 2 != 0:
            raise ValueError(
                f"output_dim must be even to split between row and column "
                f"embeddings, got {output_dim}"
            )

        # Pretrained ResNet18 without its final two children.
        self.backbone = self._build_backbone()
        # 1x1 convolution: 512 backbone channels -> output_dim.
        self.proj = nn.Conv2d(512, output_dim, kernel_size=1)

        # Each table contributes half of the position channels; 50 entries
        # bound the supported feature-map height/width.
        self.row_embed = nn.Embedding(50, output_dim // 2)
        self.col_embed = nn.Embedding(50, output_dim // 2)
        self.reset_parameters()

    def _build_backbone(self) -> nn.Module:
        """Build the backbone CNN without its last two child modules.

        Returns:
            nn.Module: ResNet18 (default pretrained weights) as a sequential
            stack of every child except the final two.
        """
        resnet = models.get_model("resnet18", weights="DEFAULT")
        return nn.Sequential(*list(resnet.children())[:-2])

    def reset_parameters(self) -> None:
        """Initialize both position embedding tables from a uniform distribution."""
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images and build the matching position map.

        Args:
            x: Image tensor of shape (batch, channels, height, width)

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - features: Projected feature map of shape
                  (batch, output_dim, H, W), where H and W are the spatial
                  size of the backbone output
                - pos: Position embeddings of shape (batch, output_dim, H, W);
                  the first half of the channels encodes the column index and
                  the second half the row index

        Raises:
            IndexError: If the feature map is taller or wider than the
                position embedding tables.
        """
        x = self.backbone(x)
        features = self.proj(x)  # [B, output_dim, H, W]

        h, w = features.shape[-2:]
        max_h = self.row_embed.num_embeddings
        max_w = self.col_embed.num_embeddings
        # Check explicitly instead of relying on the embedding lookup, which
        # reports out-of-range indices poorly (a device-side assert on CUDA).
        if h > max_h or w > max_w:
            raise IndexError(
                f"feature map of size {h}x{w} exceeds the position embedding "
                f"table of size {max_h}x{max_w}"
            )

        col_idx = torch.arange(w, device=x.device)
        row_idx = torch.arange(h, device=x.device)
        x_emb = self.col_embed(col_idx)  # [W, output_dim//2]
        y_emb = self.row_embed(row_idx)  # [H, output_dim//2]

        # Tile both embeddings over the full grid, then stack them on the
        # channel axis: [H, W, output_dim].
        grid = torch.cat(
            [
                x_emb.unsqueeze(0).repeat(h, 1, 1),
                y_emb.unsqueeze(1).repeat(1, w, 1),
            ],
            dim=-1,
        )
        pos = grid.permute(2, 0, 1).unsqueeze(0)  # [1, output_dim, H, W]
        pos = pos.repeat(x.shape[0], 1, 1, 1)  # [B, output_dim, H, W]

        return features, pos


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer sequences.

    Implements the standard sinusoidal positional encoding as described
    in "Attention Is All You Need" (Vaswani et al., 2017). Even channels hold
    sines and odd channels hold cosines of the same per-pair frequencies.
    Both even and odd ``d_model`` values are supported.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """Initialize positional encoding.

        Args:
            d_model: Model dimension
            dropout: Dropout probability
            max_len: Maximum sequence length to support
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd channel than even channels,
        # so the last frequency has no cosine slot and must be dropped.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model], broadcasts over batch
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input tensor.

        Args:
            x: Input tensor of shape (seq_len, batch, d_model)

        Returns:
            torch.Tensor: Input with positional encoding added, after dropout
        """
        x = x + self.pe[: x.size(0), :, :]  # [seq_len, batch, d_model]
        return self.dropout(x)


class TransformerEncoderLayer(nn.Module):
    """Single transformer encoder layer with pre-layer normalization.

    Implements a transformer encoder layer following the pre-norm architecture
    with multi-head self-attention and position-wise feed-forward network.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """Initialize transformer encoder layer.

        Args:
            d_model: Model dimension
            nhead: Number of attention heads
            dim_feedforward: Feedforward network hidden dimension
            dropout: Dropout probability
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass through encoder layer.

        Args:
            src: Input tensor of shape (seq_len, batch, d_model)
            src_mask: Optional attention mask
            src_key_padding_mask: Optional key padding mask

        Returns:
            torch.Tensor: Output tensor of same shape as input
        """
        src2 = self.norm1(src)
        src2, _ = self.self_attn(
            src2, src2, src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(src2)

        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(F.relu(self.linear1(src2))))
        src = src + self.dropout2(src2)

        return src


class TransformerDecoderLayer(nn.Module):
    """Single transformer decoder layer with cross-attention.

    Implements a transformer decoder layer with masked self-attention,
    cross-attention to encoder memory, and position-wise feed-forward network.
    Supports query position embeddings for object detection style architectures.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """Initialize transformer decoder layer.

        Args:
            d_model: Model dimension
            nhead: Number of attention heads
            dim_feedforward: Feedforward network hidden dimension
            dropout: Dropout probability
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
        tgt_key_padding_mask: torch.Tensor | None = None,
        memory_key_padding_mask: torch.Tensor | None = None,
        query_pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass through decoder layer.

        Args:
            tgt: Target tensor of shape (tgt_len, batch, d_model)
            memory: Memory tensor from encoder of shape (src_len, batch, d_model)
            tgt_mask: Optional target attention mask
            memory_mask: Optional memory attention mask
            tgt_key_padding_mask: Optional target key padding mask
            memory_key_padding_mask: Optional memory key padding mask
            query_pos: Optional query position embeddings

        Returns:
            torch.Tensor: Output tensor of same shape as target
        """
        q = k = tgt if query_pos is None else tgt + query_pos
        tgt2 = self.norm1(tgt)
        tgt2, _ = self.self_attn(
            q, k, tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)

        tgt2 = self.norm2(tgt)
        tgt2, _ = self.multihead_attn(
            query=tgt2 if query_pos is None else tgt2 + query_pos,
            key=memory,
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        tgt = tgt + self.dropout2(tgt2)

        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt


class TransformerEncoder(nn.Module):
    """Sequential stack of encoder layers followed by a LayerNorm.

    The input is passed through every ``TransformerEncoderLayer`` in order
    and the result of the last layer is layer-normalized.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """Build the layer stack and the closing normalization.

        Args:
            d_model: Model dimension
            nhead: Number of attention heads
            num_encoder_layers: How many encoder layers to stack
            dim_feedforward: Hidden width of each layer's feed-forward network
            dropout: Dropout probability passed to every layer
        """
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(num_encoder_layers):
            stack.append(
                TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
            )
        self.layers = stack
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        src: torch.Tensor,
        mask: torch.Tensor | None = None,
        src_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply every encoder layer in order, then the final LayerNorm.

        Args:
            src: Input tensor of shape (seq_len, batch, d_model)
            mask: Optional attention mask, given to each layer as ``src_mask``
            src_key_padding_mask: Optional key padding mask, given to each layer

        Returns:
            torch.Tensor: Encoded output tensor
        """
        hidden = src
        # The same masks are shared by all layers.
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        encoded = self.norm(hidden)
        return encoded


class TransformerDecoder(nn.Module):
    """Sequential stack of decoder layers followed by a LayerNorm.

    The target is passed through every ``TransformerDecoderLayer`` in order,
    each time attending to the same memory, and the result of the last layer
    is layer-normalized.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ):
        """Build the layer stack and the closing normalization.

        Args:
            d_model: Model dimension
            nhead: Number of attention heads
            num_decoder_layers: How many decoder layers to stack
            dim_feedforward: Hidden width of each layer's feed-forward network
            dropout: Dropout probability passed to every layer
        """
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(num_decoder_layers):
            stack.append(
                TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            )
        self.layers = stack
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
        tgt_key_padding_mask: torch.Tensor | None = None,
        memory_key_padding_mask: torch.Tensor | None = None,
        query_pos: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply every decoder layer in order, then the final LayerNorm.

        Args:
            tgt: Target tensor of shape (tgt_len, batch, d_model)
            memory: Memory tensor from encoder of shape (src_len, batch, d_model)
            tgt_mask: Optional target attention mask
            memory_mask: Optional memory attention mask
            tgt_key_padding_mask: Optional target key padding mask
            memory_key_padding_mask: Optional memory key padding mask
            query_pos: Optional query position embeddings

        Returns:
            torch.Tensor: Decoded output tensor
        """
        # Masks and query positions are identical for every layer, so gather
        # them once and reuse them on each call.
        shared_kwargs = {
            "tgt_mask": tgt_mask,
            "memory_mask": memory_mask,
            "tgt_key_padding_mask": tgt_key_padding_mask,
            "memory_key_padding_mask": memory_key_padding_mask,
            "query_pos": query_pos,
        }

        hidden = tgt
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, **shared_kwargs)

        decoded = self.norm(hidden)
        return decoded