File size: 1,132 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""2×2×2 时空压缩卷积。

将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
"""

from __future__ import annotations

import torch
import torch.nn as nn


class TemporalCompress2x2x2(nn.Module):
    """``Conv3d(D, D, kernel=2, stride=2)`` followed by a standard LayerNorm.

    Each non-overlapping 2x2x2 (time x height x width) cube of tokens is merged
    into one token, so the token count shrinks by 8x while the channel
    dimension ``dim`` is preserved.
    """

    def __init__(self, dim: int = 768) -> None:
        """Build the compressor.

        Args:
            dim: Channel dimension ``D`` of the input tokens; the output keeps
                the same dimension.
        """
        super().__init__()
        self.dim = dim
        # kernel == stride == 2 with no padding -> non-overlapping 2x2x2 windows.
        self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
        self.norm = nn.LayerNorm(dim)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[int, int, int]]:
        """Compress a ``[B, T, H, W, D]`` token grid by 2 along T, H and W.

        Layout:
            [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
            -> [B, T'*H'*W', D] -> LayerNorm

        Args:
            x: Input tokens of shape ``[B, T, H, W, D]`` with ``D == self.dim``.
                T, H and W must each be at least 2. An odd size is floored by
                the stride-2 convolution, i.e. the trailing frame/row/column
                is dropped.

        Returns:
            A tuple ``(tokens, (T', H', W'))`` where ``T' = T // 2``,
            ``H' = H // 2``, ``W' = W // 2`` and ``tokens`` has shape
            ``[B, T'*H'*W', D]`` (equal to ``T*H*W // 8`` tokens when T, H
            and W are all even), flattened in (t, h, w) row-major order.

        Raises:
            ValueError: If ``x`` is not 5-D or its last dimension is not
                ``self.dim``.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D [B, T, H, W, D] tensor, got shape {tuple(x.shape)}"
            )
        if x.shape[-1] != self.dim:
            raise ValueError(
                f"expected last dimension {self.dim}, got {x.shape[-1]}"
            )
        # Conv3d expects channels-first: [B, D, T, H, W].
        y = self.conv(x.permute(0, 4, 1, 2, 3).contiguous())
        b, d, t2, h2, w2 = y.shape
        # Back to channels-last and flatten the (T', H', W') grid into tokens.
        y = y.permute(0, 2, 3, 4, 1).reshape(b, t2 * h2 * w2, d)
        return self.norm(y), (t2, h2, w2)