import functools
from typing import Iterable, Union

import torch
from einops import rearrange, repeat

import comfy.ops
ops = comfy.ops.disable_weight_init

from .diffusionmodules.model import (
    AttnBlock,
    Decoder,
    ResnetBlock,
)
from .diffusionmodules.openaimodel import ResBlock, timestep_embedding
from .attention import BasicTransformerBlock


def partialclass(cls, *args, **kwargs):
    # Return a subclass of cls whose __init__ has the given arguments pre-bound.
    class NewCls(cls):
        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)

    return NewCls
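
# Usage sketch (illustrative): pre-bind constructor arguments so callers can
# instantiate the class later without repeating them, e.g.
#   ConvCls = partialclass(AE3DConv, video_kernel_size=3)
#   conv = ConvCls(in_channels=4, out_channels=4, kernel_size=3, padding=1)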


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
        if video_kernel_size is None:
            video_kernel_size = [3, 1, 1]
        # 3D residual block that mixes features across the frame axis; the
        # timestep embedding is skipped (skip_t_emb=True), so emb_channels=0.
        self.time_stack = ResBlock(
            channels=out_channels,
            emb_channels=0,
            dropout=dropout,
            dims=3,
            use_scale_shift_norm=False,
            use_conv=False,
            up=False,
            down=False,
            kernel_size=video_kernel_size,
            use_checkpoint=False,
            skip_t_emb=True,
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, bs):
        if self.merge_strategy == "fixed":
            # "fixed": use the stored alpha verbatim.
            return self.mix_factor
        elif self.merge_strategy == "learned":
            # "learned": sigmoid keeps the blend weight in (0, 1).
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError()

    def forward(self, x, temb, skip_video=False, timesteps=None):
        b, c, h, w = x.shape
        if timesteps is None:
            timesteps = b

        # Spatial pass over the flattened (b*t) batch of frames.
        x = super().forward(x, temb)

        if not skip_video:
            x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)

            # Temporal pass: the 3D ResBlock mixes features across frames.
            x = self.time_stack(x, temb)

            # Blend the temporal (x) and per-frame (x_mix) branches.
            alpha = self.get_alpha(bs=b // timesteps).to(x.device)
            x = alpha * x + (1.0 - alpha) * x_mix

            x = rearrange(x, "b c t h w -> (b t) c h w")
        return x
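
# Shape walk-through (illustrative numbers): given x of shape (10, 64, 32, 32)
# with timesteps=5 (2 clips of 5 frames each), the temporal branch reshapes to
# (2, 64, 5, 32, 32), runs the 3D ResBlock along the frame axis, alpha-blends
# with the per-frame branch, and flattens back to (10, 64, 32, 32).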


class AE3DConv(ops.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # "Same" padding per dimension so the temporal conv preserves shape.
        if isinstance(video_kernel_size, Iterable):
            padding = [int(k // 2) for k in video_kernel_size]
        else:
            padding = int(video_kernel_size // 2)

        self.time_mix_conv = ops.Conv3d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=video_kernel_size,
            padding=padding,
        )

    def forward(self, input, timesteps=None, skip_video=False):
        if timesteps is None:
            timesteps = input.shape[0]
        # 2D conv per frame, then a 3D conv across frames.
        x = super().forward(input)
        if skip_video:
            return x
        x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
        x = self.time_mix_conv(x)
        return rearrange(x, "b c t h w -> (b t) c h w")
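
# Illustrative shapes (hypothetical values): for a single clip of 8 frames,
#   conv = AE3DConv(4, 4, video_kernel_size=3, kernel_size=3, padding=1)
#   y = conv(torch.randn(8, 4, 32, 32), timesteps=8)  # -> (8, 4, 32, 32)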


class AttnVideoBlock(AttnBlock):
    def __init__(
        self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
    ):
        super().__init__(in_channels)
        # No context and single-headed, matching the base class.
        self.time_mix_block = BasicTransformerBlock(
            dim=in_channels,
            n_heads=1,
            d_head=in_channels,
            checkpoint=False,
            ff_in=True,
        )

        # Small MLP that embeds the frame index; its output is added to the
        # tokens before temporal attention.
        time_embed_dim = self.in_channels * 4
        self.video_time_embed = torch.nn.Sequential(
            ops.Linear(self.in_channels, time_embed_dim),
            torch.nn.SiLU(),
            ops.Linear(time_embed_dim, self.in_channels),
        )

        self.merge_strategy = merge_strategy
        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif self.merge_strategy == "learned":
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def forward(self, x, timesteps=None, skip_time_block=False):
        if skip_time_block:
            return super().forward(x)

        if timesteps is None:
            timesteps = x.shape[0]

        x_in = x
        x = self.attention(x)
        h, w = x.shape[2:]
        x = rearrange(x, "b c h w -> b (h w) c")

        # Build a per-frame positional embedding from the frame index.
        x_mix = x
        num_frames = torch.arange(timesteps, device=x.device)
        num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
        num_frames = rearrange(num_frames, "b t -> (b t)")
        t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
        emb = self.video_time_embed(t_emb)  # (b*t, in_channels)
        emb = emb[:, None, :]
        x_mix = x_mix + emb

        # Temporal attention, then alpha-blend with the spatial branch.
        alpha = self.get_alpha().to(x.device)
        x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
        x = alpha * x + (1.0 - alpha) * x_mix

        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)

        return x_in + x

    def get_alpha(self):
        if self.merge_strategy == "fixed":
            return self.mix_factor
        elif self.merge_strategy == "learned":
            return torch.sigmoid(self.mix_factor)
        else:
            raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
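
# Note on the blend: in AttnVideoBlock.forward, alpha weights the spatial
# attention branch and (1 - alpha) the temporal branch; "fixed" uses the stored
# value verbatim, while "learned" squashes the parameter through a sigmoid.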


def make_time_attn(
    in_channels,
    attn_type="vanilla",
    attn_kwargs=None,
    alpha: float = 0,
    merge_strategy: str = "learned",
    conv_op=ops.Conv2d,
):
    # attn_type, attn_kwargs and conv_op are accepted for interface
    # compatibility with the non-temporal attention factory but are unused.
    return partialclass(
        AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
    )
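
# Note: make_time_attn returns a class (AttnVideoBlock with in_channels
# pre-bound), not an instance; callers are expected to instantiate it.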


class Conv2DWrapper(torch.nn.Conv2d):
    # Plain 2D conv that accepts and ignores extra kwargs (e.g. timesteps),
    # so it can stand in where a time-aware op is expected.
    def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
        return super().forward(input)


class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        self.video_kernel_size = video_kernel_size
        self.alpha = alpha
        self.merge_strategy = merge_strategy
        self.time_mode = time_mode
        assert (
            self.time_mode in self.available_time_modes
        ), f"time_mode parameter has to be in {self.available_time_modes}"

        # Swap in the time-aware ops according to time_mode before the base
        # Decoder builds its layers.
        if self.time_mode != "attn-only":
            kwargs["conv_out_op"] = partialclass(
                AE3DConv, video_kernel_size=self.video_kernel_size
            )
        if self.time_mode not in ["conv-only", "only-last-conv"]:
            kwargs["attn_op"] = partialclass(
                make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy
            )
        if self.time_mode not in ["attn-only", "only-last-conv"]:
            kwargs["resnet_op"] = partialclass(
                VideoResBlock,
                video_kernel_size=self.video_kernel_size,
                alpha=self.alpha,
                merge_strategy=self.merge_strategy,
            )

        super().__init__(*args, **kwargs)

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        if self.time_mode == "attn-only":
            raise NotImplementedError("TODO")
        else:
            return (
                self.conv_out.time_mix_conv.weight
                if not skip_time_mix
                else self.conv_out.weight
            )
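
# Minimal usage sketch (illustrative values; assumes the base Decoder accepts
# the standard ctor kwargs from .diffusionmodules.model and forwards extra
# forward() kwargs such as timesteps to the time-aware ops):
#   dec = VideoDecoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                      attn_resolutions=(), in_channels=3, resolution=256,
#                      z_channels=4, video_kernel_size=3, time_mode="conv-only")
#   # z holds b*t latent frames flattened along the batch axis:
#   frames = dec(z, timesteps=t)  # (b*t, 4, h, w) -> (b*t, 3, 8*h, 8*w)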