Instructions to use ViTeX-Bench/ViTeX-Edit-14B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use ViTeX-Bench/ViTeX-Edit-14B with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("ViTeX-Bench/ViTeX-Edit-14B", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| from typing import Set, Tuple, Optional, List | |
| from enum import Enum | |
| import math | |
| import einops | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer | |
class AudioProcessor(nn.Module):
    """Turn audio waveforms into log-mel spectrograms, resampling first when needed."""

    def __init__(
        self,
        sample_rate: int = 16000,
        mel_bins: int = 64,
        mel_hop_length: int = 160,
        n_fft: int = 1024,
    ) -> None:
        """
        Args:
            sample_rate: Rate (Hz) the mel transform is configured for; inputs
                are resampled to this rate before the transform is applied.
            mel_bins: Number of mel filterbank channels (`n_mels`).
            mel_hop_length: Hop between successive frames, in samples.
            n_fft: FFT size; also used as the analysis window length.
        """
        super().__init__()
        self.sample_rate = sample_rate
        mel_config = {
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "win_length": n_fft,
            "hop_length": mel_hop_length,
            "f_min": 0.0,
            # Upper band edge sits at the Nyquist frequency.
            "f_max": sample_rate / 2.0,
            "n_mels": mel_bins,
            "window_fn": torch.hann_window,
            "center": True,
            "pad_mode": "reflect",
            "power": 1.0,
            "mel_scale": "slaney",
            "norm": "slaney",
        }
        self.mel_transform = torchaudio.transforms.MelSpectrogram(**mel_config)

    def resample_waveform(
        self,
        waveform: torch.Tensor,
        source_rate: int,
        target_rate: int,
    ) -> torch.Tensor:
        """Return `waveform` at `target_rate`; the input is returned untouched when rates match."""
        if source_rate != target_rate:
            converted = torchaudio.functional.resample(waveform, source_rate, target_rate)
            # Keep the caller's device and dtype regardless of what the resampler produced.
            return converted.to(device=waveform.device, dtype=waveform.dtype)
        return waveform

    def waveform_to_mel(
        self,
        waveform: torch.Tensor,
        waveform_sample_rate: int,
    ) -> torch.Tensor:
        """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
        audio = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
        # Clamp before the log so silent bins do not produce -inf.
        log_mel = torch.log(torch.clamp(self.mel_transform(audio), min=1e-5))
        log_mel = log_mel.to(device=audio.device, dtype=audio.dtype)
        # Swap the last two axes so time precedes the mel axis.
        return log_mel.permute(0, 1, 3, 2).contiguous()
class AudioPatchifier(Patchifier):
    """Patchifier for spectrogram/audio latents: one token per latent time frame."""

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        """
        Args:
            patch_size: Patch extent stored as `(1, patch_size, patch_size)`.
            sample_rate: Waveform sampling rate, used to express latent
                positions in seconds.
            hop_length: Spectrogram hop length in samples.
            audio_latent_downsample_factor: Number of spectrogram frames per
                latent frame.
            is_causal: When True, timestamps are shifted back to account for
                causal receptive fields.
            shift: Integer offset added to the latent indices before they are
                converted to timestamps.
        """
        self._patch_size = (1, patch_size, patch_size)
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift

    # NOTE(review): defined as a plain method here; confirm whether the
    # Patchifier base expects a property instead.
    def patch_size(self) -> Tuple[int, int, int]:
        """Return the stored `(1, patch_size, patch_size)` tuple."""
        return self._patch_size

    def get_token_count(self, tgt_shape: AudioLatentShape) -> int:
        """Return the number of tokens, which equals `tgt_shape.frames`."""
        return tgt_shape.frames

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Map latent indices `[start_latent, end_latent)` to timestamps in seconds.

        Args:
            start_latent: Inclusive first latent index.
            end_latent: Exclusive last latent index.
            dtype: Dtype of the index tensor the timestamps are derived from.
            device: Device for the result; CPU when omitted.
        """
        target_device = torch.device("cpu") if device is None else device
        latent_index = torch.arange(start_latent, end_latent, dtype=dtype, device=target_device)
        mel_index = latent_index * self.audio_latent_downsample_factor
        if self.is_causal:
            # Step back by one latent frame worth of mel frames, plus one mel
            # frame of causal offset, never going below frame zero.
            causal_offset = 1
            mel_index = (mel_index + causal_offset - self.audio_latent_downsample_factor).clip(min=0)
        return mel_index * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build a `(B, 1, T, 2)` tensor of `[start, end)` timestamps per latent frame.

        Args:
            batch_size: Number of sequences the timings are broadcast over.
            num_steps: Number of latent frames to produce timestamps for.
            device: Device for the result; CPU when omitted.
        """
        target_device = device if device is not None else torch.device("cpu")

        def _edges(offset: int) -> torch.Tensor:
            # offset 0 gives frame starts, offset 1 gives frame ends.
            first = self.shift + offset
            seconds = self._get_audio_latent_time_in_sec(
                first,
                first + num_steps,
                torch.float32,
                target_device,
            )
            return seconds.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        return torch.stack([_edges(0), _edges(1)], dim=-1)

    def patchify(
        self,
        audio_latents: torch.Tensor,
    ) -> torch.Tensor:
        """
        Flatten `(b, c, t, f)` latents into `(b, t, c * f)` tokens.

        Timing metadata is not produced here; call `get_patch_grid_bounds`
        when timestamps are needed.
        """
        return einops.rearrange(audio_latents, "b c t f -> b t (c f)")

    def unpatchify(
        self,
        audio_latents: torch.Tensor,
        output_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Restore `(b, c, t, f)` latents from `(b, t, c * f)` tokens.

        Args:
            audio_latents: Flattened tokens to unpatchify.
            output_shape: Supplies `channels` and `mel_bins` for the split.
        """
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=output_shape.channels,
            f=output_shape.mel_bins,
        )

    def unpatchify_audio(
        self,
        audio_latents: torch.Tensor,
        channels: int,
        mel_bins: int
    ) -> torch.Tensor:
        """Same as `unpatchify`, taking `channels` and `mel_bins` as plain integers."""
        return einops.rearrange(
            audio_latents,
            "b t (c f) -> b c t f",
            c=channels,
            f=mel_bins,
        )

    def get_patch_grid_bounds(
        self,
        output_shape: AudioLatentShape | VideoLatentShape,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return `[start, end)` timestamps in seconds for every token from `patchify`.

        The result has shape `[batch_size, 1, time_steps, 2]`; the last axis
        holds the start and end timestamp of each latent frame.

        Args:
            output_shape: Audio latent shape providing `batch` and `frames`.
            device: Target device for the returned tensor.

        Raises:
            ValueError: If `output_shape` is not an `AudioLatentShape`.
        """
        if isinstance(output_shape, AudioLatentShape):
            return self._compute_audio_timings(output_shape.batch, output_shape.frames, device)
        raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates")
class AttentionType(Enum):
    """Selects which attention implementation `make_attn` should build."""

    # Full softmax attention (`AttnBlock`).
    VANILLA = "vanilla"
    # Reserved; `make_attn` raises NotImplementedError for it.
    LINEAR = "linear"
    # No attention; `make_attn` returns an identity module.
    NONE = "none"
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over all spatial positions of a 2D feature map, with a residual add."""

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        """
        Args:
            in_channels: Channel count of the input (and output) feature map.
            norm_type: Normalization applied before the q/k/v projections.
        """
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        # 1x1 convolutions act as per-position linear projections.
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` plus the projected attention output; the result has the same shape as `x`."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        positions = height * width

        # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
        query = query.reshape(batch, channels, positions).contiguous()
        query = query.permute(0, 2, 1).contiguous()  # (b, hw, c)
        key = key.reshape(batch, channels, positions).contiguous()  # (b, c, hw)
        scores = torch.bmm(query, key).contiguous()  # (b, hw, hw)
        scores = scores * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # out[b, c, j] = sum_i value[b, c, i] * weights_t[b, i, j]
        value = value.reshape(batch, channels, positions).contiguous()
        weights_t = weights.permute(0, 2, 1).contiguous()  # (b, hw of key, hw of query)
        attended = torch.bmm(value, weights_t).contiguous()  # (b, c, hw of query)
        attended = attended.reshape(batch, channels, height, width).contiguous()

        return x + self.proj_out(attended)
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> torch.nn.Module:
    """
    Build the attention module selected by `attn_type`.

    Args:
        in_channels: Channel count handed to the attention block.
        attn_type: Which attention variant to create.
        norm_type: Normalization type forwarded to `AttnBlock`.

    Returns:
        An `AttnBlock` for VANILLA, or `torch.nn.Identity` for NONE.

    Raises:
        NotImplementedError: For `AttentionType.LINEAR`.
        ValueError: For any other value.
    """
    if attn_type == AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)
    if attn_type == AttentionType.NONE:
        return torch.nn.Identity()
    if attn_type == AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
    raise ValueError(f"Unknown attention type: {attn_type}")
class CausalityAxis(Enum):
    """Names the spatial axis along which convolutions are made causal."""

    # No causal axis; padding is split between both sides.
    NONE = None
    # Causal along the last tensor axis.
    WIDTH = "width"
    # Causal along the second-to-last tensor axis.
    HEIGHT = "height"
    # Padded like WIDTH in CausalConv2d, but with different padding in Downsample
    # and no trimming in Upsample.
    WIDTH_COMPATIBILITY = "width-compatibility"
class CausalConv2d(torch.nn.Module):
    """
    2D convolution with manual, optionally one-sided padding.

    All padding along the selected causal axis is placed on the leading side
    (left for width, top for height), so an output position only sees inputs
    at the same or earlier positions on that axis. The other axis, and both
    axes when `causality_axis` is NONE, are padded on both sides.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int = 1,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size, as an int or `(height, width)` pair.
            stride: Convolution stride.
            dilation: Dilation, as an int or `(height, width)` pair.
            groups: Number of convolution groups.
            bias: Whether the convolution has a bias term.
            causality_axis: Axis that receives one-sided padding.

        Raises:
            ValueError: If `causality_axis` is not a known `CausalityAxis`.
        """
        super().__init__()
        self.causality_axis = causality_axis

        # Normalise scalars to (height, width) pairs.
        kernel_size = torch.nn.modules.utils._pair(kernel_size)
        dilation = torch.nn.modules.utils._pair(dilation)

        # Total padding needed per axis to cover the dilated kernel extent.
        total_h = (kernel_size[0] - 1) * dilation[0]
        total_w = (kernel_size[1] - 1) * dilation[1]
        split_h = (total_h // 2, total_h - total_h // 2)
        split_w = (total_w // 2, total_w - total_w // 2)

        # F.pad order for the last two axes: (left, right, top, bottom).
        axis = self.causality_axis
        if axis == CausalityAxis.NONE:
            self.padding = (*split_w, *split_h)
        elif axis == CausalityAxis.WIDTH or axis == CausalityAxis.WIDTH_COMPATIBILITY:
            self.padding = (total_w, 0, *split_h)
        elif axis == CausalityAxis.HEIGHT:
            self.padding = (*split_w, total_h, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # Padding is applied in forward(), so the wrapped convolution uses none.
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pad `x` with the precomputed padding, then convolve."""
        padded = F.pad(x, self.padding)
        return self.conv(padded)
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    stride: int = 1,
    padding: tuple[int, int, int, int] | None = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> torch.nn.Module:
    """
    Build a 2D convolution, causal when an axis is given and plain otherwise.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, int or pair.
        stride: Convolution stride.
        padding: Padding for the plain `Conv2d` path only; defaults to
            `kernel_size // 2` per axis. Ignored on the causal path.
        dilation: Dilation rate.
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
        causality_axis: Any non-None value (including `CausalityAxis.NONE`)
            selects `CausalConv2d`, which computes its own padding.

    Returns:
        A `CausalConv2d` when `causality_axis` is not None, else a `torch.nn.Conv2d`.
    """
    if causality_axis is not None:
        return CausalConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causality_axis=causality_axis,
        )

    if padding is None:
        # Half the kernel on each side.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)

    return torch.nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
# Negative slope passed to leaky_relu by ResBlock1 and ResBlock2.
LRELU_SLOPE = 0.1
class ResBlock1(torch.nn.Module):
    """Three residual stages, each a dilated Conv1d followed by an undilated Conv1d."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)):
        """
        Args:
            channels: Input and output channel count of every convolution.
            kernel_size: Kernel size shared by all convolutions.
            dilation: Dilation of the first convolution in each of the three stages.
        """
        super().__init__()

        def _same_conv(rate: int) -> torch.nn.Conv1d:
            # Stride 1 with "same" padding keeps the sequence length unchanged.
            return torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=rate, padding="same")

        self.convs1 = torch.nn.ModuleList([_same_conv(dilation[stage]) for stage in range(3)])
        self.convs2 = torch.nn.ModuleList([_same_conv(1) for _ in range(3)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each stage as `x = conv2(lrelu(conv1(lrelu(x)))) + x`."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2, strict=True):
            branch = dilated_conv(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            branch = plain_conv(torch.nn.functional.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x
class ResBlock2(torch.nn.Module):
    """Two residual stages, each a single dilated Conv1d."""

    def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)):
        """
        Args:
            channels: Input and output channel count of every convolution.
            kernel_size: Kernel size shared by both convolutions.
            dilation: Dilation of the first and second convolution.
        """
        super().__init__()
        self.convs = torch.nn.ModuleList(
            [
                # Stride 1 with "same" padding keeps the sequence length unchanged.
                torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[stage], padding="same")
                for stage in range(2)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each stage as `x = conv(lrelu(x)) + x`."""
        for layer in self.convs:
            branch = layer(torch.nn.functional.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x
class ResnetBlock(torch.nn.Module):
    """Residual block: two norm → SiLU → conv stages with an optional embedding projection and shortcut."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Channels of the input feature map.
            out_channels: Channels of the output; defaults to `in_channels`.
            conv_shortcut: When channel counts differ, use a 3x3 shortcut
                convolution instead of a 1x1 one.
            dropout: Dropout probability applied before the second convolution.
            temb_channels: Size of the embedding vector; the projection layer
                is only created when this is positive.
            norm_type: Normalization layer type for both norms.
            causality_axis: Causal axis forwarded to every convolution.

        Raises:
            ValueError: If a causal axis is combined with group normalization.
        """
        super().__init__()
        self.causality_axis = causality_axis

        is_causal = self.causality_axis != CausalityAxis.NONE
        if is_causal and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")

        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # First stage.
        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.non_linearity = torch.nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

        # Second stage.
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

        # A shortcut projection is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: Input feature map.
            temb: Optional embedding; when given it is projected and added
                per channel after the first convolution.

        Returns:
            Shortcut of `x` plus the residual branch output.
        """
        hidden = self.conv1(self.non_linearity(self.norm1(x)))

        if temb is not None:
            # Broadcast the projected embedding over both spatial axes.
            hidden = hidden + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        hidden = self.non_linearity(self.norm2(hidden))
        hidden = self.conv2(self.dropout(hidden))

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + hidden
class Downsample(torch.nn.Module):
    """
    Halve the spatial resolution with a stride-2 convolution or 2x2 average pooling.

    In convolution mode the input is padded manually, with the padding layout
    chosen by `causality_axis`.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        """
        Args:
            in_channels: Channels of the input (and output) feature map.
            with_conv: Use a strided 3x3 convolution instead of average pooling.
            causality_axis: Axis that receives one-sided padding.

        Raises:
            ValueError: If a causal axis is requested with `with_conv=False`.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis

        wants_causal = self.causality_axis != CausalityAxis.NONE
        if wants_causal and not self.with_conv:
            raise ValueError("causality is only supported when `with_conv=True`.")

        if self.with_conv:
            # Conv2d cannot pad asymmetrically, so padding is applied in forward().
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the downsampled feature map."""
        if not self.with_conv:
            # Only reachable with causality_axis NONE (enforced in __init__).
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)

        # Padding order: (left, right, top, bottom).
        axis = self.causality_axis
        if axis == CausalityAxis.NONE:
            pad = (0, 1, 0, 1)
        elif axis == CausalityAxis.WIDTH:
            pad = (2, 0, 0, 1)
        elif axis == CausalityAxis.HEIGHT:
            pad = (0, 1, 2, 0)
        elif axis == CausalityAxis.WIDTH_COMPATIBILITY:
            pad = (1, 0, 0, 1)
        else:
            raise ValueError(f"Invalid causality_axis: {self.causality_axis}")

        padded = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        return self.conv(padded)
def build_downsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Assemble the encoder stages: residual blocks, optional attention, and downsampling.

    Each stage is a bare `torch.nn.Module` with `block` and `attn` module
    lists; every stage except the last also has a `downsample` layer.

    Returns:
        The list of stages and the channel count produced by the final stage.
    """
    stages = torch.nn.ModuleList()
    level_res = resolution
    # Input multiplier of level i is the output multiplier of level i - 1 (1 for the first).
    input_mult = (1, *tuple(ch_mult))
    channels = ch
    last_level = num_resolutions - 1

    for level in range(num_resolutions):
        res_blocks = torch.nn.ModuleList()
        attn_blocks = torch.nn.ModuleList()
        channels = ch * input_mult[level]
        level_out = ch * ch_mult[level]
        # Resolution is constant within a level, so decide once.
        use_attn = level_res in attn_resolutions

        for _ in range(num_res_blocks):
            res_blocks.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=level_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = level_out
            if use_attn:
                attn_blocks.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        stage = torch.nn.Module()
        stage.block = res_blocks
        stage.attn = attn_blocks
        if level != last_level:
            stage.downsample = Downsample(channels, resamp_with_conv, causality_axis=causality_axis)
            level_res = level_res // 2
        stages.append(stage)

    return stages, channels
class Upsample(torch.nn.Module):
    """Double the spatial resolution by nearest-neighbour interpolation, optionally followed by a convolution."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Channels of the input (and output) feature map.
            with_conv: Apply a 3x3 convolution after interpolation.
            causality_axis: Causal axis forwarded to the convolution and used
                to decide which leading element is trimmed.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the upsampled feature map.

        Raises:
            ValueError: If `causality_axis` is not a known `CausalityAxis`.
        """
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        # NOTE(review): the source listing lost its indentation; the trimming
        # below is assumed to apply only on the convolution path — confirm.
        if not self.with_conv:
            return upsampled

        out = self.conv(upsampled)

        # Why the FIRST element on the causal axis is dropped (length becomes 1 + 2 * n):
        # an input [0, 1, 2] interpolates to [0, 0, 1, 1, 2, 2]; the causal conv
        # pads the front to [-, -, 0, 0, 1, 1, 2, 2], giving these windows:
        #   0: [-,-,0]   1: [-,0,0]   2: [0,0,1]
        #   3: [0,1,1]   4: [1,1,2]   5: [1,2,2]
        # Outputs 0 and 1 both depend only on the first input element, whereas
        # every later output mixes two inputs. Removing the first output
        # therefore undoes the encoder's padding (rather than removing the last).
        axis = self.causality_axis
        if axis == CausalityAxis.HEIGHT:
            return out[:, :, 1:, :]
        if axis == CausalityAxis.WIDTH:
            return out[:, :, :, 1:]
        if axis == CausalityAxis.NONE or axis == CausalityAxis.WIDTH_COMPATIBILITY:
            # Nothing is trimmed for these modes.
            return out
        raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
def build_upsampling_path(  # noqa: PLR0913
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[torch.nn.ModuleList, int]:
    """
    Assemble the decoder stages: residual blocks, optional attention, and upsampling.

    Stages are built from the coarsest level to the finest but stored in
    level order (index 0 is level 0). Each stage is a bare `torch.nn.Module`
    with `block` and `attn` module lists; every stage except level 0 also has
    an `upsample` layer.

    Returns:
        The list of stages and the channel count produced by the last stage built.
    """
    stages = torch.nn.ModuleList()
    channels = initial_block_channels
    # Resolution at the coarsest level, after (num_resolutions - 1) halvings.
    level_res = resolution // (2 ** (num_resolutions - 1))

    for level in range(num_resolutions - 1, -1, -1):
        stage = torch.nn.Module()
        stage.block = torch.nn.ModuleList()
        stage.attn = torch.nn.ModuleList()
        level_out = ch * ch_mult[level]
        # Resolution is constant within a level, so decide once.
        use_attn = level_res in attn_resolutions

        for _ in range(num_res_blocks + 1):
            stage.block.append(
                ResnetBlock(
                    in_channels=channels,
                    out_channels=level_out,
                    temb_channels=temb_channels,
                    dropout=dropout,
                    norm_type=norm_type,
                    causality_axis=causality_axis,
                )
            )
            channels = level_out
            if use_attn:
                stage.attn.append(make_attn(channels, attn_type=attn_type, norm_type=norm_type))

        if level != 0:
            stage.upsample = Upsample(channels, resamp_with_conv, causality_axis=causality_axis)
            level_res *= 2

        # Prepend so the list ends up ordered by level index.
        stages.insert(0, stage)

    return stages, channels
class PerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation.

    These statistics are computed over the entire dataset and stored in the
    model's checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        """Register uninitialised `std-of-means` and `mean-of-means` buffers of length `latent_channels`."""
        super().__init__()
        # Buffer names contain hyphens, so they are only reachable via get_buffer().
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def un_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x * std + mean`, with the buffers cast to the dtype and device of `x`."""
        scale = self.get_buffer("std-of-means").to(x)
        offset = self.get_buffer("mean-of-means").to(x)
        return (x * scale) + offset

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Return `(x - mean) / std`, with the buffers cast to the dtype and device of `x`."""
        offset = self.get_buffer("mean-of-means").to(x)
        scale = self.get_buffer("std-of-means").to(x)
        return (x - offset) / scale
# Passed as `audio_latent_downsample_factor` to the AudioPatchifier built by LTX2AudioEncoder.
LATENT_DOWNSAMPLE_FACTOR = 4
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> torch.nn.Module:
    """
    Build the middle block: ResNet block, optional attention, ResNet block.

    Returns:
        A bare `torch.nn.Module` with `block_1`, `attn_1` and `block_2`
        attributes; `attn_1` is an identity when `add_attention` is False.
    """

    def _resnet() -> ResnetBlock:
        # Both ResNet blocks keep the channel count unchanged.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    mid = torch.nn.Module()
    mid.block_1 = _resnet()
    if add_attention:
        mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type)
    else:
        mid.attn_1 = torch.nn.Identity()
    mid.block_2 = _resnet()
    return mid
def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Apply `mid.block_1`, `mid.attn_1`, then `mid.block_2` to `features`, passing no embedding to the blocks."""
    hidden = mid.block_1(features, temb=None)
    hidden = mid.attn_1(hidden)
    hidden = mid.block_2(hidden, temb=None)
    return hidden
| class LTX2AudioEncoder(torch.nn.Module): | |
| """ | |
| Encoder that compresses audio spectrograms into latent representations. | |
| The encoder uses a series of downsampling blocks with residual connections, | |
| attention mechanisms, and configurable causal convolutions. | |
| """ | |
| def __init__( # noqa: PLR0913 | |
| self, | |
| *, | |
| ch: int = 128, | |
| ch_mult: Tuple[int, ...] = (1, 2, 4), | |
| num_res_blocks: int = 2, | |
| attn_resolutions: Set[int] = set(), | |
| dropout: float = 0.0, | |
| resamp_with_conv: bool = True, | |
| in_channels: int = 2, | |
| resolution: int = 256, | |
| z_channels: int = 8, | |
| double_z: bool = True, | |
| attn_type: AttentionType = AttentionType.VANILLA, | |
| mid_block_add_attention: bool = False, | |
| norm_type: NormType = NormType.PIXEL, | |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, | |
| sample_rate: int = 16000, | |
| mel_hop_length: int = 160, | |
| n_fft: int = 1024, | |
| is_causal: bool = True, | |
| mel_bins: int = 64, | |
| **_ignore_kwargs, | |
| ) -> None: | |
| """ | |
| Initialize the Encoder. | |
| Args: | |
| Arguments are configuration parameters, loaded from the audio VAE checkpoint config | |
| (audio_vae.model.params.ddconfig): | |
| ch: Base number of feature channels used in the first convolution layer. | |
| ch_mult: Multiplicative factors for the number of channels at each resolution level. | |
| num_res_blocks: Number of residual blocks to use at each resolution level. | |
| attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. | |
| resolution: Input spatial resolution of the spectrogram (height, width). | |
| z_channels: Number of channels in the latent representation. | |
| norm_type: Normalization layer type to use within the network (e.g., group, batch). | |
| causality_axis: Axis along which convolutions should be causal (e.g., time axis). | |
| sample_rate: Audio sample rate in Hz for the input signals. | |
| mel_hop_length: Hop length used when computing the mel spectrogram. | |
| n_fft: FFT size used to compute the spectrogram. | |
| mel_bins: Number of mel-frequency bins in the input spectrogram. | |
| in_channels: Number of channels in the input spectrogram tensor. | |
| double_z: If True, predict both mean and log-variance (doubling latent channels). | |
| is_causal: If True, use causal convolutions suitable for streaming setups. | |
| dropout: Dropout probability used in residual and mid blocks. | |
| attn_type: Type of attention mechanism to use in attention blocks. | |
| resamp_with_conv: If True, perform resolution changes using strided convolutions. | |
| mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. | |
| """ | |
| super().__init__() | |
| self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) | |
| self.sample_rate = sample_rate | |
| self.mel_hop_length = mel_hop_length | |
| self.n_fft = n_fft | |
| self.is_causal = is_causal | |
| self.mel_bins = mel_bins | |
| self.patchifier = AudioPatchifier( | |
| patch_size=1, | |
| audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, | |
| sample_rate=sample_rate, | |
| hop_length=mel_hop_length, | |
| is_causal=is_causal, | |
| ) | |
| self.ch = ch | |
| self.temb_ch = 0 | |
| self.num_resolutions = len(ch_mult) | |
| self.num_res_blocks = num_res_blocks | |
| self.resolution = resolution | |
| self.in_channels = in_channels | |
| self.z_channels = z_channels | |
| self.double_z = double_z | |
| self.norm_type = norm_type | |
| self.causality_axis = causality_axis | |
| self.attn_type = attn_type | |
| # downsampling | |
| self.conv_in = make_conv2d( | |
| in_channels, | |
| self.ch, | |
| kernel_size=3, | |
| stride=1, | |
| causality_axis=self.causality_axis, | |
| ) | |
| self.non_linearity = torch.nn.SiLU() | |
| self.down, block_in = build_downsampling_path( | |
| ch=ch, | |
| ch_mult=ch_mult, | |
| num_resolutions=self.num_resolutions, | |
| num_res_blocks=num_res_blocks, | |
| resolution=resolution, | |
| temb_channels=self.temb_ch, | |
| dropout=dropout, | |
| norm_type=self.norm_type, | |
| causality_axis=self.causality_axis, | |
| attn_type=self.attn_type, | |
| attn_resolutions=attn_resolutions, | |
| resamp_with_conv=resamp_with_conv, | |
| ) | |
| self.mid = build_mid_block( | |
| channels=block_in, | |
| temb_channels=self.temb_ch, | |
| dropout=dropout, | |
| norm_type=self.norm_type, | |
| causality_axis=self.causality_axis, | |
| attn_type=self.attn_type, | |
| add_attention=mid_block_add_attention, | |
| ) | |
| self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) | |
| self.conv_out = make_conv2d( | |
| block_in, | |
| 2 * z_channels if double_z else z_channels, | |
| kernel_size=3, | |
| stride=1, | |
| causality_axis=self.causality_axis, | |
| ) | |
def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
    """
    Encode an audio spectrogram into normalized latents.

    The input goes through the stem convolution, the downsampling stages, the
    mid block and the output head, and the result is finally normalized with
    the per-channel statistics.

    Args:
        spectrogram: Input spectrogram of shape (batch, channels, time, frequency)
    Returns:
        Encoded latent representation of shape (batch, channels, frames, mel_bins)
    """
    features = self.conv_in(spectrogram)
    features = self._run_downsampling_path(features)
    features = run_mid_block(self.mid, features)
    return self._normalize_latents(self._finalize_output(features))
def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor:
    """Apply every resolution stage of ``self.down`` to ``h`` in order.

    Each stage runs ``self.num_res_blocks`` residual blocks (called with
    ``temb=None``), each followed by the attention block at the same index
    when the stage has a non-empty ``attn`` container. Every stage except
    the last one then applies its ``downsample`` layer.
    """
    last_level = self.num_resolutions - 1
    for level in range(self.num_resolutions):
        stage = self.down[level]
        for block_idx in range(self.num_res_blocks):
            h = stage.block[block_idx](h, temb=None)
            if stage.attn:
                h = stage.attn[block_idx](h)
        if level == last_level:
            # The final stage keeps its resolution.
            continue
        h = stage.downsample(h)
    return h
def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
    """Run the output head: normalization, non-linearity, then the final convolution."""
    activated = self.non_linearity(self.norm_out(h))
    return self.conv_out(activated)
def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor:
    """
    Normalize encoder latents using per-channel statistics.

    When the encoder is configured with ``double_z=True``, the final
    convolution produces twice the number of latent channels, typically
    interpreted as two concatenated tensors along the channel dimension
    (e.g., mean and variance or other auxiliary parameters). Only the first
    half of the channels (the "mean" component) is patchified, normalized
    and returned; the second half is not part of the return value.

    If ``double_z=False``, the encoder output already contains only the
    mean latents, so the whole tensor is normalized without any channel
    split.

    Args:
        latent_output: Output of the final convolution; indexed here as
            (batch, channels, frames, mel_bins).
    Returns:
        The normalized mean latents, restored to the 4D layout of the
        (possibly halved) input.
    """
    if self.double_z:
        means = torch.chunk(latent_output, 2, dim=1)[0]
    else:
        # Without the mean/auxiliary doubling there is nothing to split off;
        # chunking here would silently discard half of the latent channels.
        means = latent_output
    latent_shape = AudioLatentShape(
        batch=means.shape[0],
        channels=means.shape[1],
        frames=means.shape[2],
        mel_bins=means.shape[3],
    )
    # The statistics are applied in the patchified layout, then the original
    # 4D layout is restored.
    latent_patched = self.patchifier.patchify(means)
    latent_normalized = self.per_channel_statistics.normalize(latent_patched)
    return self.patchifier.unpatchify(latent_normalized, latent_shape)
class LTX2AudioDecoder(torch.nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        ch: int = 128,
        out_ch: int = 2,
        ch_mult: Tuple[int, ...] = (1, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: Optional[Set[int]] = None,
        resolution: int = 256,
        z_channels: int = 8,
        norm_type: NormType = NormType.PIXEL,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
    ) -> None:
        """
        Initialize the Decoder.

        Args:
            Arguments are configuration parameters, loaded from the audio VAE checkpoint config
            (audio_vae.model.params.ddconfig):
            - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions
            - resolution, z_channels
            - norm_type, causality_axis
            attn_resolutions: ``None`` (the default) is treated as an empty set.
            dropout: Forwarded to the mid block and the upsampling path.
            mid_block_add_attention: Forwarded to the mid block builder.
            sample_rate, mel_hop_length, is_causal: Forwarded to the audio patchifier.
            mel_bins: Frequency size of the decoded output. When ``None`` the
                latent's own ``mel_bins`` is used as the target instead.
        """
        super().__init__()
        # Build a fresh set per instance; a ``set()`` literal in the signature
        # would be created once and shared by every decoder using the default.
        if attn_resolutions is None:
            attn_resolutions = set()

        # Internal behavioural defaults that are not driven by the checkpoint.
        resamp_with_conv = True
        attn_type = AttentionType.VANILLA

        # Per-channel statistics for denormalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=sample_rate,
            hop_length=mel_hop_length,
            is_causal=is_causal,
        )

        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.out_ch = out_ch
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.z_channels = z_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis
        self.attn_type = attn_type

        # The decoder starts at the widest channel count / coarsest resolution.
        base_block_channels = ch * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = torch.nn.SiLU()
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=ch,
            ch_mult=ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=num_res_blocks,
            resolution=resolution,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=attn_resolutions,
            resamp_with_conv=resamp_with_conv,
            initial_block_channels=base_block_channels,
        )
        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Decode latent features back to audio spectrograms.

        Args:
            sample: Encoded latent representation of shape (batch, channels, frames, mel_bins)
        Returns:
            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
        """
        sample, target_shape = self._denormalize_latents(sample)
        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)
        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]:
        """
        Undo the per-channel normalization and derive the expected output shape.

        Returns:
            The un-normalized latents (same 4D layout as ``sample``) and the
            target shape the decoded spectrogram is cropped/padded to.
        """
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )
        # Statistics are applied in the patchified layout.
        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            # With a causal axis the target is shortened by (factor - 1) frames,
            # but never below a single frame.
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )
        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: torch.Tensor,
        target_shape: AudioLatentShape,
    ) -> torch.Tensor:
        """
        Adjust output shape to match target dimensions for variable-length audio.

        Dimensions larger than the target are cropped from the end; time and
        frequency dimensions smaller than the target are zero-padded at the end.

        Args:
            decoded_output: Tensor of shape (batch, channels, time, frequency)
            target_shape: AudioLatentShape describing (batch, channels, time, mel bins)
        Returns:
            Tensor adjusted to match target_shape exactly
        """
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions. Slicing past
        # the end of a dimension is a no-op, so no explicit min() is required.
        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        # Step 3: Apply padding if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # F.pad pairs start from the last dimension:
            # (freq_before, freq_after, time_before, time_after)
            padding = (
                0,
                max(freq_padding_needed, 0),
                0,
                max(time_padding_needed, 0),
            )
            decoded_output = F.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        return decoded_output[:, :target_channels, :target_time, :target_freq]

    def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor:
        """Run the stages of ``self.up`` from the highest level index down to level 0."""
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                h = block(h, temb=None)
                if stage.attn:
                    h = stage.attn[block_idx](h)
            # Level 0 is processed last and is never upsampled.
            if level != 0 and hasattr(stage, "upsample"):
                h = stage.upsample(h)
        return h

    def _finalize_output(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the output head, unless ``give_pre_end`` requests the raw features."""
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = self.non_linearity(h)
        h = self.conv_out(h)
        return torch.tanh(h) if self.tanh_out else h
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return half of the dilated kernel span ``dilation * (kernel_size - 1)``, truncated to an int."""
    dilated_span = kernel_size * dilation - dilation
    return int(dilated_span / 2)
| # --------------------------------------------------------------------------- | |
| # Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 | |
| # Adapted from https://github.com/NVIDIA/BigVGAN | |
| # --------------------------------------------------------------------------- | |
def _sinc(x: torch.Tensor) -> torch.Tensor:
    """Normalized sinc, ``sin(pi * x) / (pi * x)``, with the value at ``x == 0`` set to 1."""
    one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
    # This quotient is 0/0 at x == 0; those positions are replaced by `one` below.
    quotient = torch.sin(math.pi * x) / math.pi / x
    return torch.where(x == 0, one, quotient)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: Cutoff passed to the sinc term; ``0`` yields an all-zero kernel.
        half_width: Transition half-width used to derive the Kaiser ``beta``.
        kernel_size: Number of taps.
    Returns:
        Tensor of shape ``(1, 1, kernel_size)``. For a non-zero cutoff the taps
        are normalized to sum to one.
    """
    half_size = kernel_size // 2
    is_even = kernel_size % 2 == 0

    # Derive the Kaiser beta from the requested transition width.
    transition = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * transition + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Tap positions: half-integer offsets for even kernels, integers for odd ones.
    if is_even:
        positions = torch.arange(-half_size, half_size) + 0.5
    else:
        positions = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        return torch.zeros_like(positions).view(1, 1, kernel_size)

    kernel = 2 * cutoff * window * _sinc(2 * cutoff * positions)
    kernel /= kernel.sum()
    return kernel.view(1, 1, kernel_size)
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter using a Kaiser-windowed sinc kernel.

    Args:
        cutoff: Filter cutoff; must lie in ``[0, 0.5]``.
        half_width: Transition half-width passed to the kernel builder.
        stride: Convolution stride (values above 1 decimate the signal).
        padding: Whether to pad the input before filtering.
        padding_mode: ``F.pad`` mode used when ``padding`` is true.
        kernel_size: Number of filter taps.

    Raises:
        ValueError: If ``cutoff`` is negative or greater than 0.5.
    """

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ) -> None:
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Even kernels pad one sample less on the left, so the total padding
        # is always kernel_size - 1.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter ``x`` of shape (batch, channels, time), each channel independently."""
        _, n_channels, _ = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # Match the input's dtype/device (as UpSample1d.forward does) so a
        # buffer left in another precision does not make conv1d raise.
        kernel = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=n_channels)
| class UpSample1d(nn.Module): | |
| def __init__( | |
| self, | |
| ratio: int = 2, | |
| kernel_size: int | None = None, | |
| persistent: bool = True, | |
| window_type: str = "kaiser", | |
| ) -> None: | |
| super().__init__() | |
| self.ratio = ratio | |
| self.stride = ratio | |
| if window_type == "hann": | |
| # Hann-windowed sinc filter equivalent to torchaudio.functional.resample | |
| rolloff = 0.99 | |
| lowpass_filter_width = 6 | |
| width = math.ceil(lowpass_filter_width / rolloff) | |
| self.kernel_size = 2 * width * ratio + 1 | |
| self.pad = width | |
| self.pad_left = 2 * width * ratio | |
| self.pad_right = self.kernel_size - ratio | |
| time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff | |
| time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) | |
| window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 | |
| sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) | |
| else: | |
| # Kaiser-windowed sinc filter (BigVGAN default). | |
| self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size | |
| self.pad = self.kernel_size // ratio - 1 | |
| self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 | |
| self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 | |
| sinc_filter = kaiser_sinc_filter1d( | |
| cutoff=0.5 / ratio, | |
| half_width=0.6 / ratio, | |
| kernel_size=self.kernel_size, | |
| ) | |
| self.register_buffer("filter", sinc_filter, persistent=persistent) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| _, n_channels, _ = x.shape | |
| x = F.pad(x, (self.pad, self.pad), mode="replicate") | |
| filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1) | |
| x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels) | |
| return x[..., self.pad_left : -self.pad_right] | |
class DownSample1d(nn.Module):
    """Downsample a 1-D signal by ``ratio`` with a strided low-pass filter.

    Args:
        ratio: Integer decimation factor, used as the filter stride.
        kernel_size: Filter length; derived from ``ratio`` when ``None``.
    """

    def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None:
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2
        self.ratio = ratio
        self.kernel_size = kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the low-pass filtered, strided version of ``x``."""
        filtered = self.lowpass(x)
        return filtered
class Activation1d(nn.Module):
    """Wrap an activation between an upsampling and a downsampling stage.

    Args:
        activation: Module applied to the upsampled signal.
        up_ratio: Upsampling factor applied before the activation.
        down_ratio: Downsampling factor applied after the activation.
        up_kernel_size: Kernel size of the upsampling filter.
        down_kernel_size: Kernel size of the downsampling filter.
    """

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample, apply the activation, then downsample."""
        return self.downsample(self.act(self.upsample(x)))
class Snake(nn.Module):
    """Snake activation: ``x + sin(alpha * x) ** 2 / alpha`` with one ``alpha`` per channel.

    Args:
        in_features: Number of channels (size of the ``alpha`` parameter).
        alpha: Initial ``alpha`` value, used only when ``alpha_logscale`` is false.
        alpha_trainable: Whether ``alpha`` requires gradients.
        alpha_logscale: If true, ``alpha`` is stored in log space (initialized
            to zeros) and exponentiated in ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            initial = torch.zeros(in_features)
        else:
            initial = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial)
        self.alpha.requires_grad = alpha_trainable
        # Guards the division when alpha is very close to zero.
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation to ``x`` of shape (batch, channels, time)."""
        # Reshape (channels,) -> (1, channels, 1) for broadcasting.
        scale = self.alpha.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            scale = torch.exp(scale)
        periodic = torch.sin(x * scale).pow(2)
        return x + (1.0 / (scale + self.eps)) * periodic
class SnakeBeta(nn.Module):
    """Snake variant with separate frequency (``alpha``) and magnitude (``beta``) parameters.

    Computes ``x + sin(alpha * x) ** 2 / beta`` per channel.

    Args:
        in_features: Number of channels (size of ``alpha`` and ``beta``).
        alpha: Initial value for both ``alpha`` and ``beta`` when
            ``alpha_logscale`` is false.
        alpha_trainable: Whether ``alpha`` and ``beta`` require gradients.
        alpha_logscale: If true, both parameters are stored in log space
            (initialized to zeros) and exponentiated in ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        alpha: float = 1.0,
        alpha_trainable: bool = True,
        alpha_logscale: bool = True,
    ) -> None:
        super().__init__()
        self.alpha_logscale = alpha_logscale

        def initial_value() -> torch.Tensor:
            # alpha and beta share the same initialization rule.
            if alpha_logscale:
                return torch.zeros(in_features)
            return torch.ones(in_features) * alpha

        self.alpha = nn.Parameter(initial_value())
        self.alpha.requires_grad = alpha_trainable
        self.beta = nn.Parameter(initial_value())
        self.beta.requires_grad = alpha_trainable
        # Guards the division when beta is very close to zero.
        self.eps = 1e-9

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation to ``x`` of shape (batch, channels, time)."""
        # Reshape (channels,) -> (1, channels, 1) for broadcasting.
        frequency = self.alpha.unsqueeze(0).unsqueeze(-1)
        magnitude = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            frequency = torch.exp(frequency)
            magnitude = torch.exp(magnitude)
        periodic = torch.sin(x * frequency).pow(2)
        return x + (1.0 / (magnitude + self.eps)) * periodic
class AMPBlock1(nn.Module):
    """Residual block of three (activation, dilated conv, activation, conv) stages.

    Args:
        channels: Channel count of every convolution (input and output).
        kernel_size: Kernel size of every convolution.
        dilation: Dilations of the first convolution in each of the three stages.
        activation: ``"snakebeta"`` selects SnakeBeta; any other value selects Snake.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: tuple[int, int, int] = (1, 3, 5),
        activation: str = "snake",
    ) -> None:
        super().__init__()
        if activation == "snakebeta":
            act_cls = SnakeBeta
        else:
            act_cls = Snake

        def length_preserving_conv(rate: int) -> nn.Conv1d:
            # Stride 1 with padding derived from the dilated kernel size.
            return nn.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=rate,
                padding=get_padding(kernel_size, rate),
            )

        # Exactly three stages, using the first three dilation entries.
        stage_dilations = (dilation[0], dilation[1], dilation[2])
        self.convs1 = nn.ModuleList([length_preserving_conv(rate) for rate in stage_dilations])
        self.convs2 = nn.ModuleList([length_preserving_conv(1) for _ in stage_dilations])
        self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in self.convs1])
        self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in self.convs2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the three residual stages in sequence."""
        stages = zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True)
        for dilated_conv, plain_conv, first_act, second_act in stages:
            residual = plain_conv(second_act(dilated_conv(first_act(x))))
            x = x + residual
        return x
| class LTX2Vocoder(torch.nn.Module): | |
| """ | |
| LTX2Vocoder model for synthesizing audio from Mel spectrograms. | |
| Args: | |
| resblock_kernel_sizes: List of kernel sizes for the residual blocks. | |
| This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. | |
| upsample_rates: List of upsampling rates. | |
| This value is read from the checkpoint at `config.vocoder.upsample_rates`. | |
| upsample_kernel_sizes: List of kernel sizes for the upsampling layers. | |
| This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. | |
| resblock_dilation_sizes: List of dilation sizes for the residual blocks. | |
| This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. | |
| upsample_initial_channel: Initial number of channels for the upsampling layers. | |
| This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. | |
| resblock: Type of residual block to use ("1", "2", or "AMP1"). | |
| This value is read from the checkpoint at `config.vocoder.resblock`. | |
| output_sampling_rate: Waveform sample rate. | |
| This value is read from the checkpoint at `config.vocoder.output_sampling_rate`. | |
| activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1". | |
| use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True). | |
| apply_final_activation: Whether to apply the final tanh/clamp activation. | |
| use_bias_at_final: Whether to use bias in the final conv layer. | |
| """ | |
| def __init__( # noqa: PLR0913 | |
| self, | |
| resblock_kernel_sizes: List[int] | None = [3, 7, 11], | |
| upsample_rates: List[int] | None = [6, 5, 2, 2, 2], | |
| upsample_kernel_sizes: List[int] | None = [16, 15, 8, 4, 4], | |
| resblock_dilation_sizes: List[List[int]] | None = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], | |
| upsample_initial_channel: int = 1024, | |
| resblock: str = "1", | |
| output_sampling_rate: int = 24000, | |
| activation: str = "snake", | |
| use_tanh_at_final: bool = True, | |
| apply_final_activation: bool = True, | |
| use_bias_at_final: bool = True, | |
| ) -> None: | |
| super().__init__() | |
| # Mutable default values are not supported as default arguments. | |
| if resblock_kernel_sizes is None: | |
| resblock_kernel_sizes = [3, 7, 11] | |
| if upsample_rates is None: | |
| upsample_rates = [6, 5, 2, 2, 2] | |
| if upsample_kernel_sizes is None: | |
| upsample_kernel_sizes = [16, 15, 8, 4, 4] | |
| if resblock_dilation_sizes is None: | |
| resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] | |
| self.output_sampling_rate = output_sampling_rate | |
| self.num_kernels = len(resblock_kernel_sizes) | |
| self.num_upsamples = len(upsample_rates) | |
| self.use_tanh_at_final = use_tanh_at_final | |
| self.apply_final_activation = apply_final_activation | |
| self.is_amp = resblock == "AMP1" | |
| # All production checkpoints are stereo: 128 input channels (2 stereo channels x 64 mel | |
| # bins each), 2 output channels. | |
| self.conv_pre = nn.Conv1d( | |
| in_channels=128, | |
| out_channels=upsample_initial_channel, | |
| kernel_size=7, | |
| stride=1, | |
| padding=3, | |
| ) | |
| resblock_cls = ResBlock1 if resblock == "1" else AMPBlock1 | |
| self.ups = nn.ModuleList( | |
| nn.ConvTranspose1d( | |
| upsample_initial_channel // (2**i), | |
| upsample_initial_channel // (2 ** (i + 1)), | |
| kernel_size, | |
| stride, | |
| padding=(kernel_size - stride) // 2, | |
| ) | |
| for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)) | |
| ) | |
| final_channels = upsample_initial_channel // (2 ** len(upsample_rates)) | |
| self.resblocks = nn.ModuleList() | |
| for i in range(len(upsample_rates)): | |
| ch = upsample_initial_channel // (2 ** (i + 1)) | |
| for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): | |
| if self.is_amp: | |
| self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation)) | |
| else: | |
| self.resblocks.append(resblock_cls(ch, kernel_size, dilations)) | |
| if self.is_amp: | |
| self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels)) | |
| else: | |
| self.act_post = nn.LeakyReLU() | |
| # All production checkpoints are stereo: this final conv maps `final_channels` to 2 output channels (stereo). | |
| self.conv_post = nn.Conv1d( | |
| in_channels=final_channels, | |
| out_channels=2, | |
| kernel_size=7, | |
| stride=1, | |
| padding=3, | |
| bias=use_bias_at_final, | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Forward pass of the vocoder. | |
| Args: | |
| x: Input Mel spectrogram tensor. Can be either: | |
| - 3D: (batch_size, time, mel_bins) for mono | |
| - 4D: (batch_size, 2, time, mel_bins) for stereo | |
| Returns: | |
| Audio waveform tensor of shape (batch_size, out_channels, audio_length) | |
| """ | |
| x = x.transpose(2, 3) # (batch, channels, time, mel_bins) -> (batch, channels, mel_bins, time) | |
| if x.dim() == 4: # stereo | |
| assert x.shape[1] == 2, "Input must have 2 channels for stereo" | |
| x = einops.rearrange(x, "b s c t -> b (s c) t") | |
| x = self.conv_pre(x) | |
| for i in range(self.num_upsamples): | |
| if not self.is_amp: | |
| x = F.leaky_relu(x, LRELU_SLOPE) | |
| x = self.ups[i](x) | |
| start = i * self.num_kernels | |
| end = start + self.num_kernels | |
| # Evaluate all resblocks with the same input tensor so they can run | |
| # independently (and thus in parallel on accelerator hardware) before | |
| # aggregating their outputs via mean. | |
| block_outputs = torch.stack( | |
| [self.resblocks[idx](x) for idx in range(start, end)], | |
| dim=0, | |
| ) | |
| x = block_outputs.mean(dim=0) | |
| x = self.act_post(x) | |
| x = self.conv_post(x) | |
| if self.apply_final_activation: | |
| x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1) | |
| return x | |
class _STFTFn(nn.Module):
    """Implements STFT as a convolution with precomputed DFT x Hann-window bases.

    The DFT basis rows (real parts first, imaginary parts second, as split in
    ``forward``) multiplied by the causal Hann window are stored as buffers and
    loaded from the checkpoint. Using the exact bfloat16 bases from training
    ensures the mel values fed to the BWE generator are bit-identical to what
    it was trained on.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        # Zero placeholders with the final shapes: one row per real and per
        # imaginary frequency bin.
        basis_shape = ((filter_length // 2 + 1) * 2, 1, filter_length)
        self.register_buffer("forward_basis", torch.zeros(*basis_shape))
        self.register_buffer("inverse_basis", torch.zeros(*basis_shape))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrogram from a batch of waveforms.

        Applies causal (left-only) padding of win_length - hop_length samples so that
        each output frame depends only on past and present input — no lookahead.

        Args:
            y: Waveform tensor of shape (B, T).
        Returns:
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
        """
        waveform = y.unsqueeze(1) if y.dim() == 2 else y  # (B, 1, T)
        causal_pad = max(0, self.win_length - self.hop_length)  # left-only
        waveform = F.pad(waveform, (causal_pad, 0))
        spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
        # First half of the output channels is the real part, second half the imaginary part.
        half = spec.shape[1] // 2
        real = spec[:, :half]
        imag = spec[:, half:]
        magnitude = torch.sqrt(real**2 + imag**2)
        # atan2 is evaluated in float32 and cast back to the working dtype.
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase
class MelSTFT(nn.Module):
    """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
    waveform and projecting the linear magnitude spectrum onto the mel filterbank.
    The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
    (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
    ) -> None:
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
        # Zero placeholder of shape [n_mels, n_freqs]; load_state_dict overwrites
        # it with the checkpoint's exact bfloat16 filterbank
        # (vocoder.mel_stft.mel_basis).
        filterbank = torch.zeros(n_mel_channels, filter_length // 2 + 1)
        self.register_buffer("mel_basis", filterbank)

    def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).
        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        energy = torch.norm(magnitude, dim=1)
        # Project onto the mel filterbank in the magnitude's dtype.
        filterbank = self.mel_basis.to(magnitude.dtype)
        mel = torch.matmul(filterbank, magnitude)
        # Floor at 1e-5 before the log to avoid log(0).
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy
| class LTX2VocoderWithBWE(nn.Module): | |
| """LTX2Vocoder with bandwidth extension (BWE) upsampling. | |
| Chains a mel-to-wav vocoder with a BWE module that upsamples the output | |
| to a higher sample rate. The BWE computes a mel spectrogram from the | |
| vocoder output, runs it through a second generator to predict a residual, | |
| and adds it to a sinc-resampled skip connection. | |
| """ | |
def __init__(
    self,
    input_sampling_rate: int = 16000,
    output_sampling_rate: int = 48000,
    hop_length: int = 80,
) -> None:
    """Build the base vocoder, the BWE generator, the mel front-end and the skip resampler.

    Args:
        input_sampling_rate: Sample rate assigned to the base vocoder's output.
        output_sampling_rate: Sample rate assigned to the BWE generator's output.
        hop_length: Hop length of the causal mel STFT.
    """
    super().__init__()

    def amp_generator(**specific: object) -> LTX2Vocoder:
        # Both generators share the AMP1 / snakebeta configuration and differ
        # only in the keyword arguments supplied by the caller.
        return LTX2Vocoder(
            resblock_kernel_sizes=[3, 7, 11],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            resblock="AMP1",
            activation="snakebeta",
            use_tanh_at_final=False,
            use_bias_at_final=False,
            **specific,
        )

    self.vocoder = amp_generator(
        upsample_rates=[5, 2, 2, 2, 2, 2],
        upsample_kernel_sizes=[11, 4, 4, 4, 4, 4],
        upsample_initial_channel=1536,
        apply_final_activation=True,
        output_sampling_rate=input_sampling_rate,
    )
    self.bwe_generator = amp_generator(
        upsample_rates=[6, 5, 2, 2, 2],
        upsample_kernel_sizes=[12, 11, 4, 4, 4],
        upsample_initial_channel=512,
        apply_final_activation=False,
        output_sampling_rate=output_sampling_rate,
    )
    self.mel_stft = MelSTFT(
        filter_length=512,
        hop_length=hop_length,
        win_length=512,
        n_mel_channels=64,
    )
    self.input_sampling_rate = input_sampling_rate
    self.output_sampling_rate = output_sampling_rate
    self.hop_length = hop_length

    # Compute the resampler on CPU so the sinc filter is materialized even when
    # the model is constructed on meta device (SingleGPUModelBuilder pattern).
    # The filter is not stored in the checkpoint (persistent=False).
    upsample_ratio = output_sampling_rate // input_sampling_rate
    with torch.device("cpu"):
        self.resampler = UpSample1d(ratio=upsample_ratio, persistent=False, window_type="hann")
| def conv_pre(self) -> nn.Conv1d: | |
| return self.vocoder.conv_pre | |
| def conv_post(self) -> nn.Conv1d: | |
| return self.vocoder.conv_post | |
| def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor: | |
| """Compute log-mel spectrogram from waveform using causal STFT bases. | |
| Args: | |
| audio: Waveform tensor of shape (B, C, T). | |
| Returns: | |
| mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames). | |
| """ | |
| batch, n_channels, _ = audio.shape | |
| flat = audio.reshape(batch * n_channels, -1) # (B*C, T) | |
| mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) | |
| return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) | |
| def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: | |
| """Run the full vocoder + BWE forward pass. | |
| Args: | |
| mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo | |
| or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward. | |
| Returns: | |
| Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1]. | |
| """ | |
| x = self.vocoder(mel_spec) | |
| _, _, length_low_rate = x.shape | |
| output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate | |
| # Pad to multiple of hop_length for exact mel frame count | |
| remainder = length_low_rate % self.hop_length | |
| if remainder != 0: | |
| x = F.pad(x, (0, self.hop_length - remainder)) | |
| # Compute mel spectrogram from vocoder output: (B, C, n_mels, T_frames) | |
| mel = self._compute_mel(x) | |
| # LTX2Vocoder.forward expects (B, C, T, mel_bins) — transpose before calling bwe_generator | |
| mel_for_bwe = mel.transpose(2, 3) # (B, C, T_frames, mel_bins) | |
| residual = self.bwe_generator(mel_for_bwe) | |
| skip = self.resampler(x) | |
| assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" | |
| return torch.clamp(residual + skip, -1, 1)[..., :output_length] | |