aliensmn's picture
Mirror from https://github.com/kijai/ComfyUI-WanVideoWrapper
cf812a0 verified
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class CausalConv1d(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size=3,
stride=1,
dilation=1,
pad_mode='replicate',
**kwargs):
super().__init__()
self.pad_mode = pad_mode
padding = (kernel_size - 1, 0) # T
self.time_causal_padding = padding
self.conv = nn.Conv1d(
chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
class MotionEncoder_tc(nn.Module):
def __init__(self,
in_dim: int,
hidden_dim: int,
num_heads=int,
need_global=True,
dtype=None,
device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.num_heads = num_heads
self.need_global = need_global
self.conv1_local = CausalConv1d(
in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
if need_global:
self.conv1_global = CausalConv1d(
in_dim, hidden_dim // 4, 3, stride=1)
self.norm1 = nn.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.act = nn.SiLU()
self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
if need_global:
self.final_linear = nn.Linear(hidden_dim, hidden_dim,
**factory_kwargs)
self.norm1 = nn.LayerNorm(
hidden_dim // 4,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm2 = nn.LayerNorm(
hidden_dim // 2,
elementwise_affine=False,
eps=1e-6,
**factory_kwargs)
self.norm3 = nn.LayerNorm(
hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
def forward(self, x):
x = rearrange(x, 'b t c -> b c t')
x_ori = x.clone()
b, c, t = x.shape
x = self.conv1_local(x)
x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
x = torch.cat([x, padding], dim=-2)
x_local = x.clone()
if not self.need_global:
return x_local
x = self.conv1_global(x_ori)
x = rearrange(x, 'b c t -> b t c')
x = self.norm1(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv2(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm2(x)
x = self.act(x)
x = rearrange(x, 'b t c -> b c t')
x = self.conv3(x)
x = rearrange(x, 'b c t -> b t c')
x = self.norm3(x)
x = self.act(x)
x = self.final_linear(x)
x = rearrange(x, '(b n) t c -> b t n c', b=b)
return x, x_local