Instructions to use BryanW/43.wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BryanW/43.wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BryanW/43.wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2024-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| ############################################################################## | |
| """Simple implementation of AutoEncoderKL for LTX v0.95.""" | |
| from einops import rearrange | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_outputs import AutoencoderKLOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffnext.models.autoencoders.modeling_utils import DiagonalGaussianDistribution | |
| from diffnext.models.autoencoders.modeling_utils import DecoderOutput, TilingMixin | |
class Conv3d(nn.Conv3d):
    """3D convolution with explicit replicate padding on the temporal axis.

    Height and width are zero-padded by ``kernel // 2`` inside the convolution
    itself. The temporal axis gets no built-in padding; instead ``pad1`` (causal:
    ``kernel_t - 1`` replicated frames in front) or ``pad2`` (non-causal:
    ``(kernel_t - 1) // 2`` replicated frames on each side) is applied before
    convolving. Both pads are identities when the temporal kernel size is 1.

    Extra keyword:
        causal: choose ``pad1`` (True, default) or ``pad2`` (False).
    """

    def __init__(self, *args, **kwargs):
        self.causal = kwargs.pop("causal", True)
        super(Conv3d, self).__init__(*args, **kwargs)
        kernel_t, kernel_h, kernel_w = self.kernel_size
        # Override nn.Conv3d's padding: spatial only, time is padded explicitly below.
        self.padding = (0, kernel_h // 2, kernel_w // 2)
        front = kernel_t - 1
        no_spatial = (0, 0, 0, 0)
        if kernel_t == 1:
            self.pad1, self.pad2 = nn.Identity(), nn.Identity()
        else:
            self.pad1 = nn.ReplicationPad3d(no_spatial + (front, 0))
            self.pad2 = nn.ReplicationPad3d(no_spatial + (front // 2, front // 2))

    def forward(self, x):
        temporal_pad = self.pad1 if self.causal else self.pad2
        return super().forward(temporal_pad(x))
class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension (eps = 1e-8)."""

    def forward(self, x):
        # The mean square is computed in float32 so float16 inputs do not
        # underflow; only the resulting scale is cast back to the input dtype.
        mean_square = x.float().square().mean(-1, keepdim=True)
        inv_rms = (mean_square + 1e-8).rsqrt().to(x.dtype)
        return x * inv_rms
class TimeEmbed(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP."""

    def __init__(self, embed_dim, freq_dim=256):
        super(TimeEmbed, self).__init__()
        self.timestep_proj = nn.Module()
        self.timestep_proj.fc1 = nn.Linear(freq_dim, embed_dim)
        self.timestep_proj.fc2 = nn.Linear(embed_dim, embed_dim)
        # ``time_freq`` is a lazily built cache stored as a plain attribute (not a
        # buffer), so ``module.to(device)`` does not move it.
        self.freq_dim, self.time_freq = freq_dim, None

    def get_freq_embed(self, timestep) -> torch.Tensor:
        """Return the ``[cos, sin]`` sinusoidal embedding of ``timestep``.

        Args:
            timestep: Tensor of timesteps; a trailing frequency axis is appended.

        Returns:
            Tensor with a last dimension of ``2 * (freq_dim // 2)``, in
            ``timestep``'s dtype, or float32 when ``timestep`` is not floating point.
        """
        freq = self.time_freq
        # Rebuild the cache on first use and whenever the input lives on another
        # device (the cache is not tracked by ``nn.Module.to``).
        if freq is None or freq.device != timestep.device:
            dim, log_theta = self.freq_dim // 2, 9.210340371976184  # math.log(10000)
            freq = torch.arange(dim, dtype=torch.float32, device=timestep.device)
            freq = freq.mul(-log_theta / dim).exp().unsqueeze_(0)
            self.time_freq = freq
        emb = timestep.unsqueeze(-1).float() * freq
        # Casting to an integer dtype would truncate the embedding; use float32.
        out_dtype = timestep.dtype if timestep.is_floating_point() else torch.float32
        return torch.cat([emb.cos(), emb.sin()], dim=-1).to(dtype=out_dtype)

    def forward(self, temb) -> torch.Tensor:
        """Project ``temb`` through the MLP.

        A 1-D ``temb`` is treated as raw timesteps and embedded first; any other
        rank is passed to ``fc1`` as-is (it must already have width ``freq_dim``).
        """
        x = self.get_freq_embed(temb) if temb.dim() == 1 else temb
        return self.timestep_proj.fc2(nn.functional.silu(self.timestep_proj.fc1(x)))
class ResBlock(nn.Module):
    """Residual block: two (RMSNorm -> [modulate] -> SiLU -> conv) stages plus skip.

    Non-causal blocks own a learned ``scale_shift_table`` of shape ``(4, dim)``.
    It is added to ``temb`` and split into (shift1, scale1, shift2, scale2), which
    modulate the output of ``norm1`` and ``norm2``. The skip is the identity, so
    the residual add only works when ``dim == out_dim``.
    """

    def __init__(self, dim, out_dim, causal=True):
        super(ResBlock, self).__init__()
        self.norm1, self.norm2 = RMSNorm(), RMSNorm()
        self.conv1 = Conv3d(dim, out_dim, 3, causal=causal)
        self.conv2 = Conv3d(out_dim, out_dim, 3, causal=causal)
        self.nonlinearity, self.dropout = nn.SiLU(), nn.Dropout(0, inplace=True)
        if causal:
            self.scale_shift_table = None
        else:
            self.scale_shift_table = nn.Parameter(torch.randn(4, dim) / dim**0.5)

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Apply the block; ``temb`` is required when ``scale_shift_table`` is set."""
        shortcut = x
        modulated = self.scale_shift_table is not None
        if modulated:
            table = self.scale_shift_table.view(1, -1)
            mods = temb.add(table)[..., None, None, None]
            shift1, scale1, shift2, scale2 = mods.chunk(4, 1)
        h = self.norm1(x.movedim(1, -1)).movedim(-1, 1)
        if modulated:
            h = h.mul(1 + scale1).add_(shift1)
        h = self.conv1(self.nonlinearity(h))
        h = self.norm2(h.movedim(1, -1)).movedim(-1, 1)
        if modulated:
            h = h.mul(1 + scale2).add_(shift2)
        h = self.dropout(self.nonlinearity(h))
        return self.conv2(h).add_(shortcut)
class MidBlock(nn.Module):
    """Stack of ``depth`` constant-width ResBlocks.

    Non-causal mid blocks own a TimeEmbed projecting ``temb`` to the ``4 * dim``
    modulation vector shared by every ResBlock; causal ones ignore ``temb``.
    """

    def __init__(self, dim, depth=1, causal=True):
        super(MidBlock, self).__init__()
        self.time_embed = None if causal else TimeEmbed(dim * 4)
        self.resnets = nn.ModuleList(ResBlock(dim, dim, causal=causal) for _ in range(depth))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        block_temb = None if self.time_embed is None else self.time_embed(temb)
        for resnet in self.resnets:
            x = resnet(x, block_temb)
        return x
class Downsample(nn.Module):
    """Space-to-depth downsampler with an averaged residual shortcut.

    When ``stride_t > 1`` the input is first front-padded in time with
    ``stride_t - 1`` replicated frames. Both the conv branch (``out_dim // fold``
    channels) and the raw input are then folded by ``stride`` into channels. The
    shortcut averages consecutive groups of ``group_size`` folded channels so
    both branches end up with ``out_dim`` channels.
    """

    def __init__(self, dim, out_dim, stride, causal=True):
        super(Downsample, self).__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        fold = torch.Size(stride).numel()
        self.group_size = (dim * fold) // out_dim
        self.pad_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim // fold, 3, 1, causal=causal)
        self.patch_args = {
            "r": stride[0],
            "p": stride[1],
            "q": stride[2],
            "pattern": "b c (t r) (h p) (w q) -> b (c r p q) t h w",
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pad_t:
            x = nn.functional.pad(x, (0, 0, 0, 0, self.pad_t, 0), "replicate")
        folded = rearrange(x, **self.patch_args)
        shortcut = folded.unflatten(1, (-1, self.group_size)).mean(dim=2)
        return rearrange(self.conv(x), **self.patch_args).add_(shortcut)
class Upsample(nn.Module):
    """Depth-to-space upsampler with a channel-repeated residual shortcut.

    The conv emits ``out_dim * fold`` channels, which are unfolded by ``stride``.
    The shortcut unfolds the raw input and repeats it ``repeats`` times along
    channels. When ``stride_t > 1`` the first ``stride_t - 1`` output frames of
    both branches are dropped, so T input frames give ``(T - 1) * stride_t + 1``
    output frames.
    """

    def __init__(self, dim, out_dim, stride, causal=False):
        super(Upsample, self).__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        fold = torch.Size(stride).numel()
        self.repeats = (out_dim * fold) // dim
        self.slice_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim * fold, 3, 1, causal=causal)
        self.patch_args = {
            "r": stride[0],
            "p": stride[1],
            "q": stride[2],
            "pattern": "b (c r p q) t h w -> b c (t r) (h p) (w q)",
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = rearrange(x, **self.patch_args).repeat(1, self.repeats, 1, 1, 1)
        out = rearrange(self.conv(x), **self.patch_args)
        if self.slice_t:
            out = out[:, :, self.slice_t :]
            shortcut = shortcut[:, :, self.slice_t :]
        return out.add_(shortcut)
class DownBlock(nn.Module):
    """``depth`` constant-width ResBlocks followed by an optional Downsample.

    ``downsample`` is one of ``"spatial"``, ``"temporal"``, ``"spatiotemporal"``,
    or empty for no downsampling.
    """

    def __init__(self, dim, out_dim, depth=1, causal=True, downsample=""):
        super(DownBlock, self).__init__()
        self.resnets = nn.ModuleList(ResBlock(dim, dim, causal=causal) for _ in range(depth))
        self.downsamplers = nn.ModuleList()
        if downsample:
            strides = {"spatial": (1, 2, 2), "temporal": (2, 1, 1), "spatiotemporal": 2}
            self.downsamplers.append(Downsample(dim, out_dim, strides[downsample], causal=causal))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in (*self.resnets, *self.downsamplers):
            x = layer(x)
        return x
class UpBlock(nn.Module):
    """Optional stride-2 Upsample followed by ``depth`` timestep-modulated ResBlocks.

    ``upscale_factor`` only gates whether the Upsample exists (``> 1``); its
    stride is always 2. The block's TimeEmbed projects ``temb`` to the
    ``4 * out_dim`` modulation vector consumed by each ResBlock.
    """

    def __init__(self, dim, out_dim, depth=1, causal=False, upscale_factor=2):
        super(UpBlock, self).__init__()
        self.time_embed = TimeEmbed(out_dim * 4)
        self.resnets, self.upsamplers = nn.ModuleList(), nn.ModuleList()
        if upscale_factor > 1:
            self.upsamplers.append(Upsample(dim, out_dim, 2, causal=causal))
        self.resnets.extend(ResBlock(out_dim, out_dim, causal=causal) for _ in range(depth))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        for upsampler in self.upsamplers:
            x = upsampler(x)
        block_temb = self.time_embed(temb)
        for resnet in self.resnets:
            x = resnet(x, block_temb)
        return x
class Encoder(nn.Module):
    """Causal VAE encoder.

    Pixels are folded ``patch_size x patch_size`` into channels, then pass through
    a causal input conv, DownBlocks (spatial, temporal, spatiotemporal,
    spatiotemporal), a causal mid block, and RMSNorm + SiLU. A final conv emits
    ``out_dim + 1`` channels.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super(Encoder, self).__init__()
        self.patch_args = {"p": patch_size, "q": patch_size}
        stage_kinds = ["spatial", "temporal", "spatiotemporal", "spatiotemporal"]
        self.conv_in = Conv3d(dim * patch_size**2, block_dims[0], 3, 1)
        self.down_blocks = nn.ModuleList()
        stages = zip(block_dims, block_depths, stage_kinds)
        for idx, (width, depth, kind) in enumerate(stages):
            self.down_blocks.append(DownBlock(width, block_dims[idx + 1], depth, downsample=kind))
        self.mid_block = MidBlock(block_dims[-1], block_depths[-1])
        self.norm_out, self.conv_act = RMSNorm(), nn.SiLU()
        self.conv_out = Conv3d(block_dims[-1], out_dim + 1, 3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = rearrange(x, "b c t (h p) (w q) -> b (c q p) t h w", **self.patch_args)
        h = self.conv_in(h)
        for block in self.down_blocks:
            h = block(h)
        h = self.mid_block(h)
        h = self.norm_out(h.movedim(1, -1)).movedim(-1, 1)
        return self.conv_out(self.conv_act(h))
class Decoder(nn.Module):
    """Timestep-conditioned VAE decoder.

    Uses the given widths in reverse order:
    - a non-causal input conv and a mid block;
    - one UpBlock per entry of ``block_depths[:-1]``, each halving the width and
      upsampling by 2;
    - an RMSNorm whose output is modulated by ``scale_shift_table`` plus the
      projected timestep;
    - SiLU and an output conv, unfolded ``patch_size x patch_size`` back to pixels.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super(Decoder, self).__init__()
        block_dims = tuple(reversed(block_dims))
        self.patch_args = {"p": patch_size, "q": patch_size}
        self.conv_in = Conv3d(dim, block_dims[0], 3, 1, causal=False)
        self.mid_block = MidBlock(block_dims[0], block_depths[-1], causal=False)
        self.up_blocks = nn.ModuleList([])
        for in_dim, depth in zip(block_dims, block_depths[:-1]):
            self.up_blocks.append(UpBlock(in_dim, in_dim // 2, depth, upscale_factor=2))
        self.norm_out, self.conv_act = RMSNorm(), nn.SiLU()
        self.conv_out = Conv3d(block_dims[-1], out_dim * patch_size**2, 3, 1, causal=False)
        self.time_embed = TimeEmbed(block_dims[-1] * 2)
        self.scale_shift_table = nn.Parameter(torch.randn(2, block_dims[-1]))
        self.timestep_scale = nn.Parameter(torch.tensor(1000, dtype=torch.float32))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Decode latents ``x`` conditioned on timesteps ``temb``.

        Args:
            x: Latent tensor (presumably ``b c t h w``; TODO confirm with callers).
            temb: Timesteps, one per sample. Defaults to zeros when omitted;
                previously ``None`` crashed at the multiplication below.

        Returns:
            Pixel tensor in ``b c t h w`` layout after un-patchifying.
        """
        if temb is None:
            temb = x.new_zeros(x.size(0))
        x = self.conv_in(x)
        # Sinusoidal features of the scaled timestep; each block projects them
        # with its own TimeEmbed MLP.
        temb = self.time_embed.get_freq_embed(temb * self.timestep_scale)
        x = self.mid_block(x, temb)
        for up_block in self.up_blocks:
            x = up_block(x, temb)
        x = self.norm_out(x.movedim(1, -1)).movedim(-1, 1)
        temb = self.time_embed(temb)
        # (shift, scale) modulation of the final normalized features.
        stats = temb.add(self.scale_shift_table.view(1, -1))[..., None, None, None].chunk(2, 1)
        x = x.mul(1 + stats[1]).add_(stats[0])
        x = self.conv_out(self.conv_act(x))
        return rearrange(x, "b (c q p) t h w -> b c t (h p) (w q)", **self.patch_args)
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, TilingMixin):
    """AutoEncoder KL.

    Wraps a causal ``Encoder`` and a timestep-conditioned ``Decoder``. Encoding
    and decoding are delegated to ``TilingMixin.tiled_encoder`` /
    ``tiled_decoder``.
    """

    # ``register_to_config`` records the init arguments so ``self.config`` (read
    # by ``scale_`` / ``unscale_``) is populated.
    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("LTXVideoDownBlock3D",) * 4,
        block_out_channels=(128, 256, 512, 1024, 2048),
        layers_per_block=(4, 6, 6, 2, 2),
        decoder_block_out_channels=(128, 256, 512, 1024),
        decoder_layers_per_block=(5, 5, 5, 5),
        act_fn="silu",
        latent_channels=128,
        sample_size=1024,
        scaling_factor=1.0,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        patch_size=4,
    ):
        super(AutoencoderKLLTXVideo, self).__init__()
        TilingMixin.__init__(self, sample_min_t=249, latent_min_t=32, sample_ovr_t=1)
        # ``patch_size`` was previously ignored here (sub-modules used their default).
        self.encoder = Encoder(
            in_channels,
            latent_channels,
            block_out_channels,
            layers_per_block,
            patch_size=patch_size,
        )
        self.decoder = Decoder(
            latent_channels,
            out_channels,
            decoder_block_out_channels,
            decoder_layers_per_block,
            patch_size=patch_size,
        )
        # NOTE(review): ``latents_mean`` / ``latents_std`` are passed to
        # torch.zeros / torch.ones, i.e. used as buffer *shapes*; the actual values
        # presumably come from the checkpoint — confirm.
        self.register_buffer("shift_factors", torch.zeros(latents_mean) if latents_mean else None)
        self.register_buffer("scaling_factors", torch.ones(latents_std) if latents_std else None)
        self.latent_dist = DiagonalGaussianDistribution

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents in place, ``(x - shift) * scale``, and return ``x``.

        Uses the per-channel buffers when present, else ``config.shift_factor`` /
        ``config.scaling_factor``.
        """
        if self.shift_factors is not None:
            return x.sub_(self.shift_factors).mul_(self.scaling_factors)
        if self.config.shift_factor:
            x.sub_(self.config.shift_factor)
        return x.mul_(self.config.scaling_factor)

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents in place (inverse of ``scale_``) and return ``x``."""
        if self.shift_factors is not None:
            return x.div_(self.scaling_factors).add_(self.shift_factors)
        x.mul_(1 / self.config.scaling_factor)
        return x.add_(self.config.shift_factor) if self.config.shift_factor else x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples into a latent posterior."""
        # ``forward`` is a no-op; routing through it presumably lets hooks
        # attached to ``forward`` (e.g. offloading) run first — confirm.
        z = self.tiled_encoder(self.forward(x))
        posterior = self.latent_dist(z)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, temb: torch.Tensor = None) -> DecoderOutput:
        """Decode the input latents.

        Args:
            z: Latents; a 4-D input gets a singleton axis inserted at dim 2 for
                decoding, which is removed from the output again.
            temb: Decoder timesteps, one per sample; defaults to zeros.
        """
        if temb is None:
            temb = torch.tensor([0] * z.size(0), dtype=z.dtype, device=z.device)
        extra_dim = 2 if z.dim() == 4 else None
        # Out-of-place so the caller's tensor is not reshaped as a side effect.
        z = z.unsqueeze(extra_dim) if extra_dim is not None else z
        x = self.tiled_decoder(self.forward(z), temb=temb)
        x = x.squeeze_(extra_dim) if extra_dim is not None else x
        return DecoderOutput(sample=x)

    def forward(self, x):  # NOOP.
        return x