WJAD / src /wjad /modules /temporal_compress.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
"""2×2×2 时空压缩卷积。
将 8 帧 × 24 × 64 = 8×1536 = 12288 个 patch tokens 压缩为
4 × 12 × 32 = 1536 个视觉 tokens。维度保持 768。
"""
from __future__ import annotations
import torch
import torch.nn as nn
class TemporalCompress2x2x2(nn.Module):
    """2x2x2 spatio-temporal token compression.

    Runs ``Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)`` over the
    ``(T, H, W)`` token grid, flattens the reduced grid back into a token
    sequence, and applies a standard ``LayerNorm`` over the channel axis.
    The channel width ``dim`` is preserved.
    """

    def __init__(self, dim: int = 768) -> None:
        """Build the compression layers.

        Args:
            dim: Channel width of the incoming tokens; used as both the input
                and output channel count of the convolution and as the
                normalized shape of the LayerNorm.
        """
        super().__init__()
        self.dim = dim
        self.conv = nn.Conv3d(dim, dim, kernel_size=2, stride=2, padding=0)
        self.norm = nn.LayerNorm(dim)

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[int, int, int]]:
        """Compress a ``[B, T, H, W, D]`` token grid by 2 along T, H and W.

        Layout through the block::

            [B, T, H, W, D] -> [B, D, T, H, W] -> Conv3d -> [B, D, T', H', W']
                            -> [B, T'*H'*W', D] -> LayerNorm

        Args:
            x: Channels-last token grid of shape ``[B, T, H, W, D]``.

        Returns:
            A pair ``(tokens, (T', H', W'))`` where ``tokens`` has shape
            ``[B, T'*H'*W', D]`` and ``T', H', W'`` are the sizes of the
            compressed grid. With kernel 2, stride 2 and no padding each of
            them is the floor of half the input size, so a trailing odd
            frame/row/column does not contribute to the output.

        Raises:
            ValueError: If ``x`` is not a 5-D tensor.
        """
        if x.ndim != 5:
            raise ValueError(
                f"expected a 5-D tensor [B, T, H, W, D], got shape {tuple(x.shape)}"
            )

        # Conv3d expects channels-first input: [B, D, T, H, W].
        channels_first = x.permute(0, 4, 1, 2, 3).contiguous()
        y = self.conv(channels_first)

        batch, channels, t2, h2, w2 = y.shape
        # Back to channels-last, then flatten the reduced grid into a sequence.
        tokens = y.permute(0, 2, 3, 4, 1).reshape(batch, t2 * h2 * w2, channels)
        return self.norm(tokens), (t2, h2, w2)