| """2×2×2 时空压缩卷积。 |
| |
| 将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为 |
| 4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class TemporalCompress2x2x2(nn.Module):
    """Halve T, H and W of a patch-token grid with a strided 3-D convolution.

    Applies ``Conv3d(dim, dim, kernel_size=2, stride=2)`` followed by a
    ``LayerNorm`` over the channel dimension. The channel width is unchanged.
    """

    def __init__(self, dim: int = 768) -> None:
        """Build the convolution and normalization layers.

        Args:
            dim: Channel width ``D`` of both the input and the output tokens.
        """
        super().__init__()
        self.dim = dim
        # Non-overlapping 2x2x2 windows: kernel size equals stride, no padding.
        self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
        self.norm = nn.LayerNorm(dim)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[int, int, int]]:
        """Compress a channels-last token grid and flatten it to a sequence.

        Layout along the way::

            [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
                            -> [B, T'*H'*W', D] -> LayerNorm

        With kernel 2, stride 2 and no padding, each of ``T'``, ``H'``, ``W'``
        is the floor of half the corresponding input size, so a trailing odd
        frame/row/column does not contribute to the output.

        Args:
            x: Tensor laid out as ``[B, T, H, W, D]``.

        Returns:
            A 2-tuple ``(tokens, (T', H', W'))`` where ``tokens`` has shape
            ``[B, T'*H'*W', D]`` and the second item is the compressed grid
            size, which callers need to un-flatten the sequence.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D input [B, T, H, W, D], got shape {tuple(x.shape)}"
            )
        # Conv3d expects channels-first input: [B, D, T, H, W].
        channels_first = x.permute(0, 4, 1, 2, 3).contiguous()
        y = self.conv(channels_first)
        batch, channels, t2, h2, w2 = y.shape
        # Back to channels-last, then merge the three grid axes into one
        # token axis so LayerNorm normalizes over the channel dimension.
        tokens = y.permute(0, 2, 3, 4, 1).reshape(batch, t2 * h2 * w2, channels)
        return self.norm(tokens), (t2, h2, w2)
|
|