Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| from einops import rearrange | |
| import torch.nn.functional as F | |
| from .ltx2_video_vae import LTX2VideoEncoder | |
class PixelShuffleND(torch.nn.Module):
    """
    Pixel shuffle generalised to one, two or three axes.

    Groups of channels are folded into the temporal and/or spatial axes: the
    channel count shrinks by the product of the active upscale factors while
    each selected axis grows by its own factor.

    Args:
        dims (int): Which axes are upscaled.
            - 1: the frame axis of a ``(b, c, f, h, w)`` tensor
            - 2: height and width of a ``(b, c, h, w)`` tensor
            - 3: depth, height and width of a ``(b, c, d, h, w)`` tensor
        upscale_factors (tuple[int, int, int], optional): Per-axis upscaling factors.
            Only the first ``dims`` entries are read.

    Note:
        This operation is equivalent to the patchifier operation for the models.
        Consider using this class instead.
    """

    def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
        super().__init__()
        assert dims in (1, 2, 3), "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Move the packed upscale factors out of the channel axis of ``x``."""
        factors = self.upscale_factors

        if self.dims == 1:
            # Temporal only: channels -> frames.
            return rearrange(x, "b (c p1) f h w -> b c (f p1) h w", p1=factors[0])

        if self.dims == 2:
            # Spatial only: channels -> height and width.
            return rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=factors[0], p2=factors[1])

        if self.dims == 3:
            # Spatiotemporal: channels -> depth, height and width.
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=factors[0],
                p2=factors[1],
                p3=factors[2],
            )

        # Only reachable if `dims` was changed after construction (or asserts are disabled).
        raise ValueError(f"Unsupported dims: {self.dims}")
class ResBlock(torch.nn.Module):
    """
    Two-convolution residual block: conv -> GroupNorm -> SiLU -> conv -> GroupNorm,
    followed by the skip connection and a final SiLU.

    Both normalisation layers use 32 groups.

    Args:
        channels (int): Channel count of the block's input and output.
        mid_channels (Optional[int]): Channel count between the two convolutions.
            Falls back to `channels` when left as `None`.
        dims (int): 2 selects Conv2d; any other value selects Conv3d. Defaults to 3.
    """

    def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
        super().__init__()
        hidden = channels if mid_channels is None else mid_channels
        conv_cls = torch.nn.Conv3d if dims != 2 else torch.nn.Conv2d

        # Submodules are created in the order they run in forward().
        self.conv1 = conv_cls(channels, hidden, kernel_size=3, padding=1)
        self.norm1 = torch.nn.GroupNorm(32, hidden)
        self.conv2 = conv_cls(hidden, channels, kernel_size=3, padding=1)
        self.norm2 = torch.nn.GroupNorm(32, channels)
        self.activation = torch.nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``SiLU(branch(x) + x)`` with the same shape as ``x``."""
        hidden = self.activation(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(hidden))
        # The skip connection is added before the last activation.
        return self.activation(branch + x)
class BlurDownsample(torch.nn.Module):
    """
    Blur-then-stride spatial downsampling with a fixed (non-learned) binomial kernel.

    Only H and W are filtered. With dims=3 the frames of a ``(b, c, f, h, w)``
    tensor are processed independently; with dims=2 the input is ``(b, c, h, w)``.
    A stride of 1 makes the module a no-op.
    """

    def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
        super().__init__()
        assert dims in (2, 3)
        assert isinstance(stride, int) and stride >= 1
        assert kernel_size >= 3 and kernel_size % 2 == 1

        self.dims = dims
        self.stride = stride
        self.kernel_size = kernel_size

        # Binomial coefficients C(kernel_size - 1, i), e.g. [1, 4, 6, 4, 1] for the
        # default kernel_size of 5. They give a smooth Gaussian-like low-pass filter
        # (a "binomial filter"). The 2D kernel is the outer product of the 1D
        # coefficients with itself, normalised so that its entries sum to one.
        coeffs = torch.tensor([math.comb(kernel_size - 1, i) for i in range(kernel_size)])
        outer = coeffs[:, None] * coeffs[None, :]
        normalised = (outer / outer.sum()).float()  # (kernel_size, kernel_size)
        self.register_buffer("kernel", normalised[None, None, :, :])  # (1, 1, kernel_size, kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample H and W of ``x`` by ``self.stride``; returns ``x`` itself when stride is 1."""
        if self.stride == 1:
            return x
        if self.dims == 2:
            return self._apply_2d(x)

        # dims == 3: fold frames into the batch axis, filter each frame, unfold again.
        batch, _, frames, _, _ = x.shape
        flat = rearrange(x, "b c f h w -> (b f) c h w")
        flat = self._apply_2d(flat)
        out_h, out_w = flat.shape[-2:]
        return rearrange(flat, "(b f) c h w -> b c f h w", b=batch, f=frames, h=out_h, w=out_w)

    def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
        """Run the blur as a strided depthwise convolution on a ``(n, c, h, w)`` tensor."""
        channels = x2d.shape[1]
        # One copy of the shared kernel per channel; groups=channels keeps channels independent.
        depthwise = self.kernel.expand(channels, 1, self.kernel_size, self.kernel_size)
        return F.conv2d(
            x2d,
            weight=depthwise,
            bias=None,
            stride=self.stride,
            padding=self.kernel_size // 2,
            groups=channels,
        )
| def _rational_for_scale(scale: float) -> Tuple[int, int]: | |
| mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} | |
| if float(scale) not in mapping: | |
| raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}") | |
| return mapping[float(scale)] | |
class SpatialRationalResampler(torch.nn.Module):
    """
    Spatial resampling by a rational factor ``num / den``.

    A learned convolution followed by a 2D pixel shuffle upsamples H and W by
    ``num``; a fixed blur with stride ``den`` then downsamples them again.
    The input is a ``(b, c, f, h, w)`` tensor and every frame is processed
    independently, so the temporal axis is left untouched.

    Args:
        mid_channels (`int`): Number of channels of the input and output features.
        scale (`float`): Spatial scaling factor. Supported values are:
            - 0.75: resample by 3/4 (reduce spatial size)
            - 1.5: resample by 3/2 (increase spatial size)
            - 2.0: upsample by 2x (double spatial size)
            - 4.0: upsample by 4x (quadruple spatial size)
            Any other value will raise a ValueError.
    """

    def __init__(self, mid_channels: int, scale: float):
        super().__init__()
        self.scale = float(scale)
        numerator, denominator = _rational_for_scale(self.scale)
        self.num = numerator
        self.den = denominator

        # The conv produces num**2 channel groups that the pixel shuffle turns into pixels.
        expanded_channels = (self.num**2) * mid_channels
        self.conv = torch.nn.Conv2d(mid_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
        self.blur_down = BlurDownsample(dims=2, stride=self.den)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample H and W of a ``(b, c, f, h, w)`` tensor, frame by frame."""
        batch, _, frames, _, _ = x.shape
        per_frame = rearrange(x, "b c f h w -> (b f) c h w")
        upscaled = self.pixel_shuffle(self.conv(per_frame))
        resampled = self.blur_down(upscaled)
        return rearrange(resampled, "(b f) c h w -> b c f h w", b=batch, f=frames)
class LTX2LatentUpsampler(torch.nn.Module):
    """
    Model to upsample VAE latents spatially and/or temporally.

    The network is: initial conv -> GroupNorm -> SiLU -> ResBlocks -> upsampler ->
    ResBlocks -> final conv. Inputs and outputs are ``(b, c, f, h, w)`` tensors.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
        spatial_scale (`float`): Scale factor for spatial upsampling (only read by the rational resampler)
        rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling

    Raises:
        ValueError: If both `spatial_upsample` and `temporal_upsample` are False.
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 1024,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
        spatial_scale: float = 2.0,
        rational_resampler: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample
        self.spatial_scale = float(spatial_scale)
        self.rational_resampler = rational_resampler

        conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d

        self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
        self.initial_activation = torch.nn.SiLU()

        self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])

        if spatial_upsample and temporal_upsample:
            # 2x in time, height and width: 8 channel groups feed the 3D pixel shuffle.
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            if rational_resampler:
                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
            else:
                # Plain 2x spatial upsampling: 4 channel groups feed the 2D pixel shuffle.
                self.upsampler = torch.nn.Sequential(
                    torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                    PixelShuffleND(2),
                )
        elif temporal_upsample:
            # 2x in time only: 2 channel groups feed the 1D pixel shuffle.
            self.upsampler = torch.nn.Sequential(
                torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError("Either spatial_upsample or temporal_upsample must be True")

        self.post_upsample_res_blocks = torch.nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def _pre_upsample(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the initial conv/norm/activation and the first stack of ResBlocks."""
        x = self.initial_conv(x)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        for block in self.res_blocks:
            x = block(x)
        return x

    def _post_upsample(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the second stack of ResBlocks and the final projection conv."""
        for block in self.post_upsample_res_blocks:
            x = block(x)
        return self.final_conv(x)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Upsample a latent of shape ``(b, c, f, h, w)``.

        Args:
            latent: Input latent tensor with `in_channels` channels.

        Returns:
            torch.Tensor: Upsampled latent with `in_channels` channels.
        """
        b, _, f, _, _ = latent.shape

        if self.dims == 2:
            # 2D convolutions: every frame is treated as an independent image.
            # NOTE(review): a temporal upsampler (Conv3d based) is not handled on this
            # path, since it receives the frame-flattened 4-D tensor.
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self._pre_upsample(x)

            if isinstance(self.upsampler, SpatialRationalResampler):
                # The rational resampler expects (b, c, f, h, w) and flattens the frames
                # itself; handing it the 4-D tensor fails when it unpacks the shape.
                # Restore the frame axis for the call and flatten again afterwards.
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
                x = self.upsampler(x)
                x = rearrange(x, "b c f h w -> (b f) c h w")
            else:
                x = self.upsampler(x)

            x = self._post_upsample(x)
            return rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

        x = self._pre_upsample(latent)

        if self.temporal_upsample:
            x = self.upsampler(x)
            # remove the first frame after upsampling.
            # This is done because the first frame encodes one pixel frame.
            x = x[:, :, 1:, :, :]
        elif isinstance(self.upsampler, SpatialRationalResampler):
            # Handles the per-frame reshaping internally.
            x = self.upsampler(x)
        else:
            # Conv2d + 2D pixel shuffle: run per frame.
            x = rearrange(x, "b c f h w -> (b f) c h w")
            x = self.upsampler(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

        return self._post_upsample(x)
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
    """
    Upsample a latent with `upsampler`, wrapping the call in the video encoder's
    per-channel statistics: the latent is un-normalized first, upsampled, and the
    result is normalized again.

    Args:
        latent: Input latent tensor of shape [B, C, F, H, W].
        video_encoder: VideoEncoder whose `per_channel_statistics` provides
            `un_normalize` and `normalize`.
        upsampler: LTX2LatentUpsampler module to perform upsampling.

    Returns:
        torch.Tensor: Upsampled and re-normalized latent tensor.
    """
    stats = video_encoder.per_channel_statistics
    raw_latent = stats.un_normalize(latent)
    upsampled = upsampler(raw_latent)
    return stats.normalize(upsampled)