"""2×2×2 时空压缩卷积。 将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为 4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。 """ from __future__ import annotations import torch import torch.nn as nn class TemporalCompress2x2x2(nn.Module): """``Conv3d(D, D, kernel=2, stride=2)`` 配合标准 LayerNorm。""" def __init__(self, dim: int = 768) -> None: super().__init__() self.dim = dim self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0) self.norm = nn.LayerNorm(dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """输入 ``[B, T, H, W, D]``;输出 ``[B, T*H*W//8, D]``。 中间排布: [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W'] -> [B, T'*H'*W', D] -> LayerNorm """ b, t, h, w, d = x.shape x_in = x.permute(0, 4, 1, 2, 3).contiguous() # [B, D, T, H, W] y = self.conv(x_in) bb, dd, t2, h2, w2 = y.shape y = y.permute(0, 2, 3, 4, 1).reshape(bb, t2 * h2 * w2, dd) return self.norm(y), (t2, h2, w2)