# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class CausalConv1d(nn.Module):
    """1D convolution that pads only on the left of the time axis, so each
    output step depends only on current and past inputs."""

    def __init__(self,
                 chan_in,
                 chan_out,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 pad_mode='replicate',
                 **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # Pad (kernel_size - 1) steps on the left of the time axis and none
        # on the right, which makes the convolution causal.
        self.time_causal_padding = (kernel_size - 1, 0)

        self.conv = nn.Conv1d(
            chan_in,
            chan_out,
            kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class MotionEncoder_tc(nn.Module):
    """Temporally causal motion encoder.

    Encodes a (batch, time, channel) feature sequence into per-head local
    motion tokens and, optionally, a global motion token stream. Two of the
    three causal convolutions use stride 2, so the time axis is downsampled
    by a factor of 4.
    """

    def __init__(self,
                 in_dim: int,
                 hidden_dim: int,
                 num_heads: int,
                 need_global: bool = True,
                 dtype=None,
                 device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.num_heads = num_heads
        self.need_global = need_global
        # Local branch produces num_heads groups of hidden_dim // 4 channels.
        self.conv1_local = CausalConv1d(
            in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
        if need_global:
            self.conv1_global = CausalConv1d(
                in_dim, hidden_dim // 4, 3, stride=1)
        self.act = nn.SiLU()
        # Each stride-2 convolution halves the time axis while doubling the
        # channel width.
        self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
        self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
        if need_global:
            self.final_linear = nn.Linear(
                hidden_dim, hidden_dim, **factory_kwargs)
        self.norm1 = nn.LayerNorm(
            hidden_dim // 4,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm2 = nn.LayerNorm(
            hidden_dim // 2,
            elementwise_affine=False,
            eps=1e-6,
            **factory_kwargs)
        self.norm3 = nn.LayerNorm(
            hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        # Learnable token appended along the head axis of the local output.
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        # x: (batch, time, channel) -> (batch, channel, time) for Conv1d.
        x = rearrange(x, 'b t c -> b c t')
        x_ori = x.clone()
        b, c, t = x.shape

        # Local branch: heads are folded into the batch axis so the shared
        # norms and convolutions are applied per head.
        x = self.conv1_local(x)
        x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)
        # Append the learnable padding token along the head axis:
        # (b, t', n, c) -> (b, t', n + 1, c).
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        if not self.need_global:
            return x_local

        # Global branch: the same norm/conv stack applied to the unsplit
        # input, yielding a single-head (n = 1) token stream.
        x = self.conv1_global(x_ori)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv2(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, 'b t c -> b c t')
        x = self.conv3(x)
        x = rearrange(x, 'b c t -> b t c')
        x = self.norm3(x)
        x = self.act(x)
        x = self.final_linear(x)
        x = rearrange(x, '(b n) t c -> b t n c', b=b)

        return x, x_local
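

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal shape check for the
# encoder above. The sizes in_dim=1024, hidden_dim=512, num_heads=4, and the
# 80-frame input are illustrative assumptions, not values from the source.
if __name__ == "__main__":
    encoder = MotionEncoder_tc(
        in_dim=1024, hidden_dim=512, num_heads=4, need_global=True)
    feats = torch.randn(2, 80, 1024)  # (batch, time, channel)
    x_global, x_local = encoder(feats)
    # The two stride-2 causal convolutions reduce t=80 to t=20, and the local
    # branch gains one padding token on the head axis (4 heads -> 5 tokens).
    print(x_global.shape)  # torch.Size([2, 20, 1, 512])
    print(x_local.shape)   # torch.Size([2, 20, 5, 512])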