| from typing import Set, Tuple, Optional, List |
| from enum import Enum |
| import math |
| import einops |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torchaudio |
| from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer |
|
|
|
|
| class AudioProcessor(nn.Module): |
| """Converts audio waveforms to log-mel spectrograms with optional resampling.""" |
|
|
| def __init__( |
| self, |
| sample_rate: int = 16000, |
| mel_bins: int = 64, |
| mel_hop_length: int = 160, |
| n_fft: int = 1024, |
| ) -> None: |
| super().__init__() |
| self.sample_rate = sample_rate |
| self.mel_transform = torchaudio.transforms.MelSpectrogram( |
| sample_rate=sample_rate, |
| n_fft=n_fft, |
| win_length=n_fft, |
| hop_length=mel_hop_length, |
| f_min=0.0, |
| f_max=sample_rate / 2.0, |
| n_mels=mel_bins, |
| window_fn=torch.hann_window, |
| center=True, |
| pad_mode="reflect", |
| power=1.0, |
| mel_scale="slaney", |
| norm="slaney", |
| ) |
|
|
| def resample_waveform( |
| self, |
| waveform: torch.Tensor, |
| source_rate: int, |
| target_rate: int, |
| ) -> torch.Tensor: |
| """Resample waveform to target sample rate if needed.""" |
| if source_rate == target_rate: |
| return waveform |
| resampled = torchaudio.functional.resample(waveform, source_rate, target_rate) |
| return resampled.to(device=waveform.device, dtype=waveform.dtype) |
|
|
| def waveform_to_mel( |
| self, |
| waveform: torch.Tensor, |
| waveform_sample_rate: int, |
| ) -> torch.Tensor: |
| """Convert waveform to log-mel spectrogram [batch, channels, time, n_mels].""" |
| waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate) |
|
|
| mel = self.mel_transform(waveform) |
| mel = torch.log(torch.clamp(mel, min=1e-5)) |
|
|
| mel = mel.to(device=waveform.device, dtype=waveform.dtype) |
| return mel.permute(0, 1, 3, 2).contiguous() |
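|
| # Usage sketch (illustrative, kept as a comment; assumes the default 16 kHz |
| # configuration): |
| # |
| #   processor = AudioProcessor() |
| #   wav = torch.randn(1, 2, 48000)  # 3 s of stereo audio already at 16 kHz |
| #   mel = processor.waveform_to_mel(wav, waveform_sample_rate=16000) |
| #   # mel.shape -> (1, 2, 301, 64): 48000 / 160 + 1 centered frames, 64 mel bins |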
|
|
|
|
| class AudioPatchifier(Patchifier): |
| def __init__( |
| self, |
| patch_size: int, |
| sample_rate: int = 16000, |
| hop_length: int = 160, |
| audio_latent_downsample_factor: int = 4, |
| is_causal: bool = True, |
| shift: int = 0, |
| ): |
| """ |
| Patchifier tailored for spectrogram/audio latents. |
| Args: |
| patch_size: Number of mel bins combined into a single patch. This |
| controls the resolution along the frequency axis. |
| sample_rate: Original waveform sampling rate. Used to map latent |
| indices back to seconds so downstream consumers can align audio |
| and video cues. |
| hop_length: Window hop length used for the spectrogram. Determines |
| how many real-time samples separate two consecutive latent frames. |
| audio_latent_downsample_factor: Ratio between spectrogram frames and |
| latent frames; compensates for additional downsampling inside the |
| VAE encoder. |
| is_causal: When True, timing is shifted to account for causal |
| receptive fields so timestamps do not peek into the future. |
| shift: Integer offset applied to the latent indices. Enables |
| constructing overlapping windows from the same latent sequence. |
| """ |
| self.hop_length = hop_length |
| self.sample_rate = sample_rate |
| self.audio_latent_downsample_factor = audio_latent_downsample_factor |
| self.is_causal = is_causal |
| self.shift = shift |
| self._patch_size = (1, patch_size, patch_size) |
|
|
| @property |
| def patch_size(self) -> Tuple[int, int, int]: |
| return self._patch_size |
|
|
| def get_token_count(self, tgt_shape: AudioLatentShape) -> int: |
| return tgt_shape.frames |
|
|
| def _get_audio_latent_time_in_sec( |
| self, |
| start_latent: int, |
| end_latent: int, |
| dtype: torch.dtype, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Converts latent indices into real-time seconds while honoring causal |
| offsets and the configured hop length. |
| Args: |
| start_latent: Inclusive start index inside the latent sequence. This |
| sets the first timestamp returned. |
| end_latent: Exclusive end index. Determines how many timestamps get |
| generated. |
| dtype: Floating-point dtype used for the returned tensor, allowing |
| callers to control precision. |
| device: Target device for the timestamp tensor. When omitted the |
| computation occurs on CPU to avoid surprising GPU allocations. |
| """ |
| if device is None: |
| device = torch.device("cpu") |
|
|
| audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) |
|
|
| audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor |
|
|
| if self.is_causal: |
| # A causal latent at index i summarizes mel frames (i-1)*factor+1 .. i*factor, |
| # so report the time of the first covered mel frame instead of peeking |
| # forward to frame i*factor. |
| causal_offset = 1 |
| audio_mel_frame = (audio_mel_frame + causal_offset - self.audio_latent_downsample_factor).clip(min=0) |
|
|
| return audio_mel_frame * self.hop_length / self.sample_rate |
|
|
| def _compute_audio_timings( |
| self, |
| batch_size: int, |
| num_steps: int, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Builds a `(B, 1, T, 2)` tensor containing timestamps for each latent frame. |
| This helper method underpins `get_patch_grid_bounds` for the audio patchifier. |
| Args: |
| batch_size: Number of sequences to broadcast the timings over. |
| num_steps: Number of latent frames (time steps) to convert into timestamps. |
| device: Device on which the resulting tensor should reside. |
| """ |
| resolved_device = device |
| if resolved_device is None: |
| resolved_device = torch.device("cpu") |
|
|
| start_timings = self._get_audio_latent_time_in_sec( |
| self.shift, |
| num_steps + self.shift, |
| torch.float32, |
| resolved_device, |
| ) |
| start_timings = start_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) |
|
|
| end_timings = self._get_audio_latent_time_in_sec( |
| self.shift + 1, |
| num_steps + self.shift + 1, |
| torch.float32, |
| resolved_device, |
| ) |
| end_timings = end_timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1) |
|
|
| return torch.stack([start_timings, end_timings], dim=-1) |
|
|
| def patchify( |
| self, |
| audio_latents: torch.Tensor, |
| ) -> torch.Tensor: |
| """ |
| Flattens the audio latent tensor along time. Use `get_patch_grid_bounds` |
| to derive timestamps for each latent frame based on the configured hop |
| length and downsampling. |
| Args: |
| audio_latents: Latent tensor to patchify. |
| Returns: |
| Flattened patch tokens tensor. Use `get_patch_grid_bounds` to compute the |
| corresponding timing metadata when needed. |
| """ |
| audio_latents = einops.rearrange( |
| audio_latents, |
| "b c t f -> b t (c f)", |
| ) |
|
|
| return audio_latents |
|
|
| def unpatchify( |
| self, |
| audio_latents: torch.Tensor, |
| output_shape: AudioLatentShape, |
| ) -> torch.Tensor: |
| """ |
| Restores the `(B, C, T, F)` spectrogram tensor from flattened patches. |
| Use `get_patch_grid_bounds` to recompute the timestamps that describe each |
| frame's position in real time. |
| Args: |
| audio_latents: Latent tensor to unpatchify. |
| output_shape: Shape of the unpatched output tensor. |
| Returns: |
| Unpatched latent tensor. Use `get_patch_grid_bounds` to compute the timing |
| metadata associated with the restored latents. |
| """ |
| |
| audio_latents = einops.rearrange( |
| audio_latents, |
| "b t (c f) -> b c t f", |
| c=output_shape.channels, |
| f=output_shape.mel_bins, |
| ) |
|
|
| return audio_latents |
|
|
| def unpatchify_audio( |
| self, |
| audio_latents: torch.Tensor, |
| channels: int, |
| mel_bins: int |
| ) -> torch.Tensor: |
| audio_latents = einops.rearrange( |
| audio_latents, |
| "b t (c f) -> b c t f", |
| c=channels, |
| f=mel_bins, |
| ) |
| return audio_latents |
|
|
| def get_patch_grid_bounds( |
| self, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Return the temporal bounds `[inclusive start, exclusive end)` for every |
| patch emitted by `patchify`. For audio this corresponds to timestamps in |
| seconds aligned with the original spectrogram grid. |
| The returned tensor has shape `[batch_size, 1, time_steps, 2]`, where: |
| - axis 1 (size 1) is a singleton placeholder dimension |
| - axis 2 (size `time_steps`) indexes the latent frames |
| - axis 3 (size 2) stores the `[start, end)` timestamps per patch |
| Args: |
| output_shape: Audio grid specification describing the number of time steps. |
| device: Target device for the returned tensor. |
| """ |
| if not isinstance(output_shape, AudioLatentShape): |
| raise ValueError("AudioPatchifier expects AudioLatentShape when computing coordinates") |
|
|
| return self._compute_audio_timings(output_shape.batch, output_shape.frames, device) |
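|
| # Patchify/timing sketch (illustrative): |
| # |
| #   patchifier = AudioPatchifier(patch_size=1)   # default hop 160 @ 16 kHz |
| #   latents = torch.randn(2, 8, 16, 16)          # (B, C, T, F) |
| #   tokens = patchifier.patchify(latents)        # -> (2, 16, 128) |
| #   shape = AudioLatentShape(batch=2, channels=8, frames=16, mel_bins=16) |
| #   bounds = patchifier.get_patch_grid_bounds(shape)  # -> (2, 1, 16, 2) |
| #   # bounds[..., 0] holds each frame's start time in seconds; causally, |
| #   # latent frame 1 starts at mel frame 1, i.e. 1 * 160 / 16000 = 0.01 s. |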
|
|
|
|
| class AttentionType(Enum): |
| """Enum for specifying the attention mechanism type.""" |
|
|
| VANILLA = "vanilla" |
| LINEAR = "linear" |
| NONE = "none" |
|
|
|
|
| class AttnBlock(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| norm_type: NormType = NormType.GROUP, |
| ) -> None: |
| super().__init__() |
| self.in_channels = in_channels |
|
|
| self.norm = build_normalization_layer(in_channels, normtype=norm_type) |
| self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
| self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
| self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
| self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| h_ = x |
| h_ = self.norm(h_) |
| q = self.q(h_) |
| k = self.k(h_) |
| v = self.v(h_) |
|
|
| # Flatten the spatial grid and compute scaled dot-product attention weights. |
| b, c, h, w = q.shape |
| q = q.reshape(b, c, h * w).contiguous() |
| q = q.permute(0, 2, 1).contiguous() |
| k = k.reshape(b, c, h * w).contiguous() |
| w_ = torch.bmm(q, k).contiguous() |
| w_ = w_ * (int(c) ** (-0.5)) |
| w_ = torch.nn.functional.softmax(w_, dim=2) |
|
|
| # Aggregate values with the attention weights (one output per query position). |
| v = v.reshape(b, c, h * w).contiguous() |
| w_ = w_.permute(0, 2, 1).contiguous() |
| h_ = torch.bmm(v, w_).contiguous() |
| h_ = h_.reshape(b, c, h, w).contiguous() |
|
|
| h_ = self.proj_out(h_) |
|
|
| return x + h_ |
|
|
|
|
| def make_attn( |
| in_channels: int, |
| attn_type: AttentionType = AttentionType.VANILLA, |
| norm_type: NormType = NormType.GROUP, |
| ) -> torch.nn.Module: |
| match attn_type: |
| case AttentionType.VANILLA: |
| return AttnBlock(in_channels, norm_type=norm_type) |
| case AttentionType.NONE: |
| return torch.nn.Identity() |
| case AttentionType.LINEAR: |
| raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") |
| case _: |
| raise ValueError(f"Unknown attention type: {attn_type}") |
|
|
|
|
| class CausalityAxis(Enum): |
| """Enum for specifying the causality axis in causal convolutions.""" |
|
|
| NONE = None |
| WIDTH = "width" |
| HEIGHT = "height" |
| WIDTH_COMPATIBILITY = "width-compatibility" |
|
|
|
|
| class CausalConv2d(torch.nn.Module): |
| """ |
| A causal 2D convolution. |
| This layer ensures that the output at position `t` along the causal axis |
| only depends on inputs at `t` and earlier. It achieves this by applying |
| asymmetric padding along the configured axis (see `CausalityAxis`) before |
| the convolution. |
| """ |
|
|
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int | tuple[int, int], |
| stride: int = 1, |
| dilation: int | tuple[int, int] = 1, |
| groups: int = 1, |
| bias: bool = True, |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, |
| ) -> None: |
| super().__init__() |
|
|
| self.causality_axis = causality_axis |
|
|
| # Normalize scalar arguments to (height, width) pairs. |
| kernel_size = torch.nn.modules.utils._pair(kernel_size) |
| dilation = torch.nn.modules.utils._pair(dilation) |
|
|
| # Total padding required along each axis for a "same"-length output. |
| pad_h = (kernel_size[0] - 1) * dilation[0] |
| pad_w = (kernel_size[1] - 1) * dilation[1] |
|
|
| # A causal axis receives all of its padding on the past side. |
| match self.causality_axis: |
| case CausalityAxis.NONE: |
| self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) |
| case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: |
| self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) |
| case CausalityAxis.HEIGHT: |
| self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) |
| case _: |
| raise ValueError(f"Invalid causality_axis: {causality_axis}") |
|
|
| # The convolution runs unpadded; padding is applied manually in forward(). |
| self.conv = torch.nn.Conv2d( |
| in_channels, |
| out_channels, |
| kernel_size, |
| stride=stride, |
| padding=0, |
| dilation=dilation, |
| groups=groups, |
| bias=bias, |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| # F.pad ordering: (left, right, top, bottom) over the last two dims. |
| x = F.pad(x, self.padding) |
| return self.conv(x) |
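|
| # Causality check sketch (illustrative): with HEIGHT causality, output row t |
| # depends only on input rows <= t, so perturbing the last row cannot change |
| # any earlier row of the output: |
| # |
| #   conv = CausalConv2d(1, 1, kernel_size=3, causality_axis=CausalityAxis.HEIGHT) |
| #   a = torch.randn(1, 1, 10, 10) |
| #   b = a.clone(); b[:, :, -1] += 1.0 |
| #   assert torch.allclose(conv(a)[:, :, :-1], conv(b)[:, :, :-1]) |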
|
|
|
|
| def make_conv2d( |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int | tuple[int, int], |
| stride: int = 1, |
| padding: tuple[int, int, int, int] | None = None, |
| dilation: int = 1, |
| groups: int = 1, |
| bias: bool = True, |
| causality_axis: CausalityAxis | None = None, |
| ) -> torch.nn.Module: |
| """ |
| Create a 2D convolution layer that can be either causal or non-causal. |
| Args: |
| in_channels: Number of input channels |
| out_channels: Number of output channels |
| kernel_size: Size of the convolution kernel |
| stride: Convolution stride |
| padding: Padding (if None, will be calculated based on causal flag) |
| dilation: Dilation rate |
| groups: Number of groups for grouped convolution |
| bias: Whether to use bias |
| causality_axis: Dimension along which to apply causality. |
| Returns: |
| Either a regular Conv2d or CausalConv2d layer |
| """ |
| if causality_axis is not None: |
| # CausalConv2d computes and applies its own asymmetric padding. |
| return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) |
| else: |
| # Default to symmetric "same" padding in the non-causal case. |
| if padding is None: |
| padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size) |
|
|
| return torch.nn.Conv2d( |
| in_channels, |
| out_channels, |
| kernel_size, |
| stride, |
| padding, |
| dilation, |
| groups, |
| bias, |
| ) |
|
|
|
|
|
|
| LRELU_SLOPE = 0.1 |
|
|
|
|
| class ResBlock1(torch.nn.Module): |
| def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int, int] = (1, 3, 5)): |
| super().__init__() |
| self.convs1 = torch.nn.ModuleList( |
| [ |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[0], |
| padding="same", |
| ), |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[1], |
| padding="same", |
| ), |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[2], |
| padding="same", |
| ), |
| ] |
| ) |
|
|
| self.convs2 = torch.nn.ModuleList( |
| [ |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=1, |
| padding="same", |
| ), |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=1, |
| padding="same", |
| ), |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=1, |
| padding="same", |
| ), |
| ] |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for conv1, conv2 in zip(self.convs1, self.convs2, strict=True): |
| xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) |
| xt = conv1(xt) |
| xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) |
| xt = conv2(xt) |
| x = xt + x |
| return x |
|
|
|
|
| class ResBlock2(torch.nn.Module): |
| def __init__(self, channels: int, kernel_size: int = 3, dilation: Tuple[int, int] = (1, 3)): |
| super().__init__() |
| self.convs = torch.nn.ModuleList( |
| [ |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[0], |
| padding="same", |
| ), |
| torch.nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[1], |
| padding="same", |
| ), |
| ] |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for conv in self.convs: |
| xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) |
| xt = conv(xt) |
| x = xt + x |
| return x |
|
|
|
|
| class ResnetBlock(torch.nn.Module): |
| def __init__( |
| self, |
| *, |
| in_channels: int, |
| out_channels: int | None = None, |
| conv_shortcut: bool = False, |
| dropout: float = 0.0, |
| temb_channels: int = 512, |
| norm_type: NormType = NormType.GROUP, |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, |
| ) -> None: |
| super().__init__() |
| self.causality_axis = causality_axis |
|
|
| if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP: |
| raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") |
| self.in_channels = in_channels |
| out_channels = in_channels if out_channels is None else out_channels |
| self.out_channels = out_channels |
| self.use_conv_shortcut = conv_shortcut |
|
|
| self.norm1 = build_normalization_layer(in_channels, normtype=norm_type) |
| self.non_linearity = torch.nn.SiLU() |
| self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) |
| if temb_channels > 0: |
| self.temb_proj = torch.nn.Linear(temb_channels, out_channels) |
| self.norm2 = build_normalization_layer(out_channels, normtype=norm_type) |
| self.dropout = torch.nn.Dropout(dropout) |
| self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) |
| if self.in_channels != self.out_channels: |
| if self.use_conv_shortcut: |
| self.conv_shortcut = make_conv2d( |
| in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis |
| ) |
| else: |
| self.nin_shortcut = make_conv2d( |
| in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis |
| ) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| temb: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| h = x |
| h = self.norm1(h) |
| h = self.non_linearity(h) |
| h = self.conv1(h) |
|
|
| if temb is not None: |
| h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] |
|
|
| h = self.norm2(h) |
| h = self.non_linearity(h) |
| h = self.dropout(h) |
| h = self.conv2(h) |
|
|
| if self.in_channels != self.out_channels: |
| x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x) |
|
|
| return x + h |
|
|
|
|
| class Downsample(torch.nn.Module): |
| """ |
| A downsampling layer that can use either a strided convolution |
| or average pooling. Supports standard and causal padding for the |
| convolutional mode. |
| """ |
|
|
| def __init__( |
| self, |
| in_channels: int, |
| with_conv: bool, |
| causality_axis: CausalityAxis = CausalityAxis.WIDTH, |
| ) -> None: |
| super().__init__() |
| self.with_conv = with_conv |
| self.causality_axis = causality_axis |
|
|
| if self.causality_axis != CausalityAxis.NONE and not self.with_conv: |
| raise ValueError("causality is only supported when `with_conv=True`.") |
|
|
| if self.with_conv: |
| # torch.nn.Conv2d only supports symmetric padding, so the asymmetric |
| # causal padding is applied manually in forward(). |
| self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| if self.with_conv: |
| # Pad so the stride-2 reduction is causal along the configured axis. |
| match self.causality_axis: |
| case CausalityAxis.NONE: |
| pad = (0, 1, 0, 1) |
| case CausalityAxis.WIDTH: |
| pad = (2, 0, 0, 1) |
| case CausalityAxis.HEIGHT: |
| pad = (0, 1, 2, 0) |
| case CausalityAxis.WIDTH_COMPATIBILITY: |
| pad = (1, 0, 0, 1) |
| case _: |
| raise ValueError(f"Invalid causality_axis: {self.causality_axis}") |
|
|
| x = torch.nn.functional.pad(x, pad, mode="constant", value=0) |
| x = self.conv(x) |
| else: |
| |
| x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) |
|
|
| return x |
|
|
|
|
| def build_downsampling_path( |
| *, |
| ch: int, |
| ch_mult: Tuple[int, ...], |
| num_resolutions: int, |
| num_res_blocks: int, |
| resolution: int, |
| temb_channels: int, |
| dropout: float, |
| norm_type: NormType, |
| causality_axis: CausalityAxis, |
| attn_type: AttentionType, |
| attn_resolutions: Set[int], |
| resamp_with_conv: bool, |
| ) -> tuple[torch.nn.ModuleList, int]: |
| """Build the downsampling path with residual blocks, attention, and downsampling layers.""" |
| down_modules = torch.nn.ModuleList() |
| curr_res = resolution |
| in_ch_mult = (1, *tuple(ch_mult)) |
| block_in = ch |
|
|
| for i_level in range(num_resolutions): |
| block = torch.nn.ModuleList() |
| attn = torch.nn.ModuleList() |
| block_in = ch * in_ch_mult[i_level] |
| block_out = ch * ch_mult[i_level] |
|
|
| for _ in range(num_res_blocks): |
| block.append( |
| ResnetBlock( |
| in_channels=block_in, |
| out_channels=block_out, |
| temb_channels=temb_channels, |
| dropout=dropout, |
| norm_type=norm_type, |
| causality_axis=causality_axis, |
| ) |
| ) |
| block_in = block_out |
| if curr_res in attn_resolutions: |
| attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) |
|
|
| down = torch.nn.Module() |
| down.block = block |
| down.attn = attn |
| if i_level != num_resolutions - 1: |
| down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) |
| curr_res = curr_res // 2 |
| down_modules.append(down) |
|
|
| return down_modules, block_in |
|
|
|
|
| class Upsample(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| with_conv: bool, |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, |
| ) -> None: |
| super().__init__() |
| self.with_conv = with_conv |
| self.causality_axis = causality_axis |
| if self.with_conv: |
| self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") |
| if self.with_conv: |
| x = self.conv(x) |
| # Nearest-neighbour interpolation doubles both spatial dimensions. Along |
| # the causal axis the leading upsampled row/column is cropped, so a causal |
| # length T maps to 2*T - 1: the exact inverse of the causal Downsample, |
| # which maps 2*T - 1 back to T. WIDTH_COMPATIBILITY skips the crop. |
| match self.causality_axis: |
| case CausalityAxis.NONE: |
| pass |
| case CausalityAxis.HEIGHT: |
| x = x[:, :, 1:, :] |
| case CausalityAxis.WIDTH: |
| x = x[:, :, :, 1:] |
| case CausalityAxis.WIDTH_COMPATIBILITY: |
| pass |
| case _: |
| raise ValueError(f"Invalid causality_axis: {self.causality_axis}") |
|
|
| return x |
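|
| # Causal round-trip sketch (illustrative): the causal Downsample maps an odd |
| # length 2*T - 1 to T, and the causal Upsample maps T back to 2*T - 1: |
| # |
| #   down = Downsample(4, with_conv=True, causality_axis=CausalityAxis.HEIGHT) |
| #   up = Upsample(4, with_conv=True, causality_axis=CausalityAxis.HEIGHT) |
| #   x = torch.randn(1, 4, 9, 8)   # causal height 9 = 2*5 - 1 |
| #   y = down(x)                   # -> (1, 4, 5, 4) |
| #   assert up(y).shape[2] == 9 |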
|
|
|
|
| def build_upsampling_path( |
| *, |
| ch: int, |
| ch_mult: Tuple[int, ...], |
| num_resolutions: int, |
| num_res_blocks: int, |
| resolution: int, |
| temb_channels: int, |
| dropout: float, |
| norm_type: NormType, |
| causality_axis: CausalityAxis, |
| attn_type: AttentionType, |
| attn_resolutions: Set[int], |
| resamp_with_conv: bool, |
| initial_block_channels: int, |
| ) -> tuple[torch.nn.ModuleList, int]: |
| """Build the upsampling path with residual blocks, attention, and upsampling layers.""" |
| up_modules = torch.nn.ModuleList() |
| block_in = initial_block_channels |
| curr_res = resolution // (2 ** (num_resolutions - 1)) |
|
|
| for level in reversed(range(num_resolutions)): |
| stage = torch.nn.Module() |
| stage.block = torch.nn.ModuleList() |
| stage.attn = torch.nn.ModuleList() |
| block_out = ch * ch_mult[level] |
|
|
| for _ in range(num_res_blocks + 1): |
| stage.block.append( |
| ResnetBlock( |
| in_channels=block_in, |
| out_channels=block_out, |
| temb_channels=temb_channels, |
| dropout=dropout, |
| norm_type=norm_type, |
| causality_axis=causality_axis, |
| ) |
| ) |
| block_in = block_out |
| if curr_res in attn_resolutions: |
| stage.attn.append(make_attn(block_in, attn_type=attn_type, norm_type=norm_type)) |
|
|
| if level != 0: |
| stage.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) |
| curr_res *= 2 |
|
|
| up_modules.insert(0, stage) |
|
|
| return up_modules, block_in |
|
|
|
|
| class PerChannelStatistics(nn.Module): |
| """ |
| Per-channel statistics for normalizing and denormalizing the latent representation. |
| These statistics are computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict. |
| """ |
|
|
| def __init__(self, latent_channels: int = 128) -> None: |
| super().__init__() |
| self.register_buffer("std-of-means", torch.empty(latent_channels)) |
| self.register_buffer("mean-of-means", torch.empty(latent_channels)) |
|
|
| def un_normalize(self, x: torch.Tensor) -> torch.Tensor: |
| return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) |
|
|
| def normalize(self, x: torch.Tensor) -> torch.Tensor: |
| return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) |
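|
| # The hyphenated buffer names ("std-of-means", "mean-of-means") mirror the |
| # checkpoint layout and are therefore read via get_buffer() rather than as |
| # attributes. Round-trip sketch (illustrative, after loading real statistics): |
| # |
| #   stats = PerChannelStatistics(latent_channels=128) |
| #   stats.load_state_dict(checkpoint_stats)   # checkpoint_stats: hypothetical |
| #   x = torch.randn(2, 16, 128)               # patchified latents, features last |
| #   assert torch.allclose(stats.un_normalize(stats.normalize(x)), x, atol=1e-5) |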
|
|
|
|
| LATENT_DOWNSAMPLE_FACTOR = 4 |
|
|
|
|
| def build_mid_block( |
| channels: int, |
| temb_channels: int, |
| dropout: float, |
| norm_type: NormType, |
| causality_axis: CausalityAxis, |
| attn_type: AttentionType, |
| add_attention: bool, |
| ) -> torch.nn.Module: |
| """Build the middle block with two ResNet blocks and optional attention.""" |
| mid = torch.nn.Module() |
| mid.block_1 = ResnetBlock( |
| in_channels=channels, |
| out_channels=channels, |
| temb_channels=temb_channels, |
| dropout=dropout, |
| norm_type=norm_type, |
| causality_axis=causality_axis, |
| ) |
| mid.attn_1 = make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else torch.nn.Identity() |
| mid.block_2 = ResnetBlock( |
| in_channels=channels, |
| out_channels=channels, |
| temb_channels=temb_channels, |
| dropout=dropout, |
| norm_type=norm_type, |
| causality_axis=causality_axis, |
| ) |
| return mid |
|
|
|
|
| def run_mid_block(mid: torch.nn.Module, features: torch.Tensor) -> torch.Tensor: |
| """Run features through the middle block.""" |
| features = mid.block_1(features, temb=None) |
| features = mid.attn_1(features) |
| return mid.block_2(features, temb=None) |
|
|
|
|
| class LTX2AudioEncoder(torch.nn.Module): |
| """ |
| Encoder that compresses audio spectrograms into latent representations. |
| The encoder uses a series of downsampling blocks with residual connections, |
| attention mechanisms, and configurable causal convolutions. |
| """ |
|
|
| def __init__( |
| self, |
| *, |
| ch: int = 128, |
| ch_mult: Tuple[int, ...] = (1, 2, 4), |
| num_res_blocks: int = 2, |
| attn_resolutions: Set[int] = set(), |
| dropout: float = 0.0, |
| resamp_with_conv: bool = True, |
| in_channels: int = 2, |
| resolution: int = 256, |
| z_channels: int = 8, |
| double_z: bool = True, |
| attn_type: AttentionType = AttentionType.VANILLA, |
| mid_block_add_attention: bool = False, |
| norm_type: NormType = NormType.PIXEL, |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, |
| sample_rate: int = 16000, |
| mel_hop_length: int = 160, |
| n_fft: int = 1024, |
| is_causal: bool = True, |
| mel_bins: int = 64, |
| **_ignore_kwargs, |
| ) -> None: |
| """ |
| Initialize the Encoder. |
| Args: |
| Arguments are configuration parameters, loaded from the audio VAE checkpoint config |
| (audio_vae.model.params.ddconfig): |
| ch: Base number of feature channels used in the first convolution layer. |
| ch_mult: Multiplicative factors for the number of channels at each resolution level. |
| num_res_blocks: Number of residual blocks to use at each resolution level. |
| attn_resolutions: Spatial resolutions (e.g., in time/frequency) at which to apply attention. |
| resolution: Nominal input resolution of the spectrogram; used to track which levels receive attention blocks. |
| z_channels: Number of channels in the latent representation. |
| norm_type: Normalization layer type to use within the network (e.g., group, batch). |
| causality_axis: Axis along which convolutions should be causal (e.g., time axis). |
| sample_rate: Audio sample rate in Hz for the input signals. |
| mel_hop_length: Hop length used when computing the mel spectrogram. |
| n_fft: FFT size used to compute the spectrogram. |
| mel_bins: Number of mel-frequency bins in the input spectrogram. |
| in_channels: Number of channels in the input spectrogram tensor. |
| double_z: If True, predict both mean and log-variance (doubling latent channels). |
| is_causal: If True, use causal convolutions suitable for streaming setups. |
| dropout: Dropout probability used in residual and mid blocks. |
| attn_type: Type of attention mechanism to use in attention blocks. |
| resamp_with_conv: If True, perform resolution changes using strided convolutions. |
| mid_block_add_attention: If True, add an attention block in the mid-level of the encoder. |
| """ |
| super().__init__() |
|
|
| # With the default configuration, the flattened per-frame feature size seen |
| # by the statistics (z_channels * downsampled mel bins = 8 * 16) equals ch (128). |
| self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) |
| self.sample_rate = sample_rate |
| self.mel_hop_length = mel_hop_length |
| self.n_fft = n_fft |
| self.is_causal = is_causal |
| self.mel_bins = mel_bins |
|
|
| self.patchifier = AudioPatchifier( |
| patch_size=1, |
| audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, |
| sample_rate=sample_rate, |
| hop_length=mel_hop_length, |
| is_causal=is_causal, |
| ) |
|
|
| self.ch = ch |
| self.temb_ch = 0 |
| self.num_resolutions = len(ch_mult) |
| self.num_res_blocks = num_res_blocks |
| self.resolution = resolution |
| self.in_channels = in_channels |
| self.z_channels = z_channels |
| self.double_z = double_z |
| self.norm_type = norm_type |
| self.causality_axis = causality_axis |
| self.attn_type = attn_type |
|
|
| |
| self.conv_in = make_conv2d( |
| in_channels, |
| self.ch, |
| kernel_size=3, |
| stride=1, |
| causality_axis=self.causality_axis, |
| ) |
|
|
| self.non_linearity = torch.nn.SiLU() |
|
|
| self.down, block_in = build_downsampling_path( |
| ch=ch, |
| ch_mult=ch_mult, |
| num_resolutions=self.num_resolutions, |
| num_res_blocks=num_res_blocks, |
| resolution=resolution, |
| temb_channels=self.temb_ch, |
| dropout=dropout, |
| norm_type=self.norm_type, |
| causality_axis=self.causality_axis, |
| attn_type=self.attn_type, |
| attn_resolutions=attn_resolutions, |
| resamp_with_conv=resamp_with_conv, |
| ) |
|
|
| self.mid = build_mid_block( |
| channels=block_in, |
| temb_channels=self.temb_ch, |
| dropout=dropout, |
| norm_type=self.norm_type, |
| causality_axis=self.causality_axis, |
| attn_type=self.attn_type, |
| add_attention=mid_block_add_attention, |
| ) |
|
|
| self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type) |
| self.conv_out = make_conv2d( |
| block_in, |
| 2 * z_channels if double_z else z_channels, |
| kernel_size=3, |
| stride=1, |
| causality_axis=self.causality_axis, |
| ) |
|
|
| def forward(self, spectrogram: torch.Tensor) -> torch.Tensor: |
| """ |
| Encode audio spectrogram into latent representations. |
| Args: |
| spectrogram: Input spectrogram of shape (batch, channels, time, frequency) |
| Returns: |
| Encoded latent representation of shape (batch, channels, frames, mel_bins) |
| """ |
| h = self.conv_in(spectrogram) |
| h = self._run_downsampling_path(h) |
| h = run_mid_block(self.mid, h) |
| h = self._finalize_output(h) |
|
|
| return self._normalize_latents(h) |
|
|
| def _run_downsampling_path(self, h: torch.Tensor) -> torch.Tensor: |
| for level in range(self.num_resolutions): |
| stage = self.down[level] |
| for block_idx in range(self.num_res_blocks): |
| h = stage.block[block_idx](h, temb=None) |
| if stage.attn: |
| h = stage.attn[block_idx](h) |
|
|
| if level != self.num_resolutions - 1: |
| h = stage.downsample(h) |
|
|
| return h |
|
|
| def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: |
| h = self.norm_out(h) |
| h = self.non_linearity(h) |
| return self.conv_out(h) |
|
|
| def _normalize_latents(self, latent_output: torch.Tensor) -> torch.Tensor: |
| """ |
| Normalize encoder latents using per-channel statistics. |
| When the encoder is configured with ``double_z=True``, the final |
| convolution produces twice the number of latent channels, typically |
| interpreted as two concatenated tensors along the channel dimension |
| (e.g., mean and variance or other auxiliary parameters). |
| This method intentionally uses only the first half of the channels |
| (the "mean" component) as input to the patchifier and normalization |
| logic. The remaining channels are left unchanged by this method and |
| are expected to be consumed elsewhere in the VAE pipeline. |
| If ``double_z=False``, the encoder output already contains only the |
| mean latents and the chunking operation simply returns that tensor. |
| """ |
| means = torch.chunk(latent_output, 2, dim=1)[0] |
| latent_shape = AudioLatentShape( |
| batch=means.shape[0], |
| channels=means.shape[1], |
| frames=means.shape[2], |
| mel_bins=means.shape[3], |
| ) |
| latent_patched = self.patchifier.patchify(means) |
| latent_normalized = self.per_channel_statistics.normalize(latent_patched) |
| return self.patchifier.unpatchify(latent_normalized, latent_shape) |
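|
| # Encoder shape sketch (illustrative, default config): two causal stride-2 |
| # stages map 61 spectrogram frames to 16 latent frames and 64 mel bins to 16: |
| # |
| #   encoder = LTX2AudioEncoder() |
| #   spec = torch.randn(1, 2, 61, 64)  # (B, C, T, F) |
| #   z = encoder(spec)                 # -> (1, 8, 16, 16) |
| #   # double_z=True makes conv_out emit 16 channels; _normalize_latents |
| #   # keeps only the 8 "mean" channels. |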
|
|
|
|
| class LTX2AudioDecoder(torch.nn.Module): |
| """ |
| Symmetric decoder that reconstructs audio spectrograms from latent features. |
| The decoder mirrors the encoder structure with configurable channel multipliers, |
| attention resolutions, and causal convolutions. |
| """ |
|
|
| def __init__( |
| self, |
| *, |
| ch: int = 128, |
| out_ch: int = 2, |
| ch_mult: Tuple[int, ...] = (1, 2, 4), |
| num_res_blocks: int = 2, |
| attn_resolutions: Set[int] = set(), |
| resolution: int = 256, |
| z_channels: int = 8, |
| norm_type: NormType = NormType.PIXEL, |
| causality_axis: CausalityAxis = CausalityAxis.HEIGHT, |
| dropout: float = 0.0, |
| mid_block_add_attention: bool = False, |
| sample_rate: int = 16000, |
| mel_hop_length: int = 160, |
| is_causal: bool = True, |
| mel_bins: int | None = 64, |
| ) -> None: |
| """ |
| Initialize the Decoder. |
| Args: |
| Arguments are configuration parameters, loaded from the audio VAE checkpoint config |
| (audio_vae.model.params.ddconfig): |
| - ch, out_ch, ch_mult, num_res_blocks, attn_resolutions |
| - resolution, z_channels |
| - norm_type, causality_axis |
| """ |
| super().__init__() |
|
|
| # Not configurable: the decoder always resamples with conv and uses vanilla attention. |
| resamp_with_conv = True |
| attn_type = AttentionType.VANILLA |
|
|
| |
| self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) |
| self.sample_rate = sample_rate |
| self.mel_hop_length = mel_hop_length |
| self.is_causal = is_causal |
| self.mel_bins = mel_bins |
| self.patchifier = AudioPatchifier( |
| patch_size=1, |
| audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, |
| sample_rate=sample_rate, |
| hop_length=mel_hop_length, |
| is_causal=is_causal, |
| ) |
|
|
| self.ch = ch |
| self.temb_ch = 0 |
| self.num_resolutions = len(ch_mult) |
| self.num_res_blocks = num_res_blocks |
| self.resolution = resolution |
| self.out_ch = out_ch |
| self.give_pre_end = False |
| self.tanh_out = False |
| self.norm_type = norm_type |
| self.z_channels = z_channels |
| self.channel_multipliers = ch_mult |
| self.attn_resolutions = attn_resolutions |
| self.causality_axis = causality_axis |
| self.attn_type = attn_type |
|
|
| base_block_channels = ch * self.channel_multipliers[-1] |
| base_resolution = resolution // (2 ** (self.num_resolutions - 1)) |
| self.z_shape = (1, z_channels, base_resolution, base_resolution) |
|
|
| self.conv_in = make_conv2d( |
| z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis |
| ) |
| self.non_linearity = torch.nn.SiLU() |
| self.mid = build_mid_block( |
| channels=base_block_channels, |
| temb_channels=self.temb_ch, |
| dropout=dropout, |
| norm_type=self.norm_type, |
| causality_axis=self.causality_axis, |
| attn_type=self.attn_type, |
| add_attention=mid_block_add_attention, |
| ) |
| self.up, final_block_channels = build_upsampling_path( |
| ch=ch, |
| ch_mult=ch_mult, |
| num_resolutions=self.num_resolutions, |
| num_res_blocks=num_res_blocks, |
| resolution=resolution, |
| temb_channels=self.temb_ch, |
| dropout=dropout, |
| norm_type=self.norm_type, |
| causality_axis=self.causality_axis, |
| attn_type=self.attn_type, |
| attn_resolutions=attn_resolutions, |
| resamp_with_conv=resamp_with_conv, |
| initial_block_channels=base_block_channels, |
| ) |
|
|
| self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) |
| self.conv_out = make_conv2d( |
| final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis |
| ) |
|
|
| def forward(self, sample: torch.Tensor) -> torch.Tensor: |
| """ |
| Decode latent features back to audio spectrograms. |
| Args: |
| sample: Encoded latent representation of shape (batch, channels, frames, mel_bins) |
| Returns: |
| Reconstructed audio spectrogram of shape (batch, channels, time, frequency) |
| """ |
| sample, target_shape = self._denormalize_latents(sample) |
|
|
| h = self.conv_in(sample) |
| h = run_mid_block(self.mid, h) |
| h = self._run_upsampling_path(h) |
| h = self._finalize_output(h) |
|
|
| return self._adjust_output_shape(h, target_shape) |
|
|
| def _denormalize_latents(self, sample: torch.Tensor) -> tuple[torch.Tensor, AudioLatentShape]: |
| latent_shape = AudioLatentShape( |
| batch=sample.shape[0], |
| channels=sample.shape[1], |
| frames=sample.shape[2], |
| mel_bins=sample.shape[3], |
| ) |
|
|
| sample_patched = self.patchifier.patchify(sample) |
| sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched) |
| sample = self.patchifier.unpatchify(sample_denormalized, latent_shape) |
|
|
| target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR |
| if self.causality_axis != CausalityAxis.NONE: |
| target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) |
|
|
| target_shape = AudioLatentShape( |
| batch=latent_shape.batch, |
| channels=self.out_ch, |
| frames=target_frames, |
| mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins, |
| ) |
|
|
| return sample, target_shape |
|
|
| def _adjust_output_shape( |
| self, |
| decoded_output: torch.Tensor, |
| target_shape: AudioLatentShape, |
| ) -> torch.Tensor: |
| """ |
| Adjust output shape to match target dimensions for variable-length audio. |
| This function handles the common case where decoded audio spectrograms need to be |
| resized to match a specific target shape. |
| Args: |
| decoded_output: Tensor of shape (batch, channels, time, frequency) |
| target_shape: AudioLatentShape describing (batch, channels, time, mel bins) |
| Returns: |
| Tensor adjusted to match target_shape exactly |
| """ |
| |
| _, _, current_time, current_freq = decoded_output.shape |
| target_channels = target_shape.channels |
| target_time = target_shape.frames |
| target_freq = target_shape.mel_bins |
|
|
| # Crop any overshoot along channels, time, and frequency. |
| decoded_output = decoded_output[ |
| :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) |
| ] |
|
|
| # Remaining deficit after cropping. |
| time_padding_needed = target_time - decoded_output.shape[2] |
| freq_padding_needed = target_freq - decoded_output.shape[3] |
|
|
| |
| if time_padding_needed > 0 or freq_padding_needed > 0: |
| # F.pad pads the last dim first: (freq_left, freq_right, time_left, time_right). |
| padding = ( |
| 0, |
| max(freq_padding_needed, 0), |
| 0, |
| max(time_padding_needed, 0), |
| ) |
| decoded_output = F.pad(decoded_output, padding) |
|
|
| # Final defensive crop to guarantee the exact target shape. |
| decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] |
|
|
| return decoded_output |
|
|
| def _run_upsampling_path(self, h: torch.Tensor) -> torch.Tensor: |
| for level in reversed(range(self.num_resolutions)): |
| stage = self.up[level] |
| for block_idx, block in enumerate(stage.block): |
| h = block(h, temb=None) |
| if stage.attn: |
| h = stage.attn[block_idx](h) |
|
|
| if level != 0 and hasattr(stage, "upsample"): |
| h = stage.upsample(h) |
|
|
| return h |
|
|
| def _finalize_output(self, h: torch.Tensor) -> torch.Tensor: |
| if self.give_pre_end: |
| return h |
|
|
| h = self.norm_out(h) |
| h = self.non_linearity(h) |
| h = self.conv_out(h) |
| return torch.tanh(h) if self.tanh_out else h |
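|
| # Decoder shape sketch (illustrative, default config; shapes only, since real |
| # use loads the per-channel statistics from a checkpoint). It inverts the |
| # encoder example above: 16 latent frames become 16 * 4 - 3 = 61 causal frames: |
| # |
| #   decoder = LTX2AudioDecoder() |
| #   z = torch.randn(1, 8, 16, 16) |
| #   spec = decoder(z)  # -> (1, 2, 61, 64) |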
|
|
|
|
| def get_padding(kernel_size: int, dilation: int = 1) -> int: |
| return int((kernel_size * dilation - dilation) / 2) |
|
|
|
|
| # The components below (Kaiser-windowed sinc filters, LowPassFilter1d, |
| # UpSample1d/DownSample1d, Activation1d, Snake/SnakeBeta) implement the |
| # anti-aliased, BigVGAN-style activations used by the vocoder. |
|
|
|
|
| def _sinc(x: torch.Tensor) -> torch.Tensor: |
| return torch.where( |
| x == 0, |
| torch.tensor(1.0, device=x.device, dtype=x.dtype), |
| torch.sin(math.pi * x) / math.pi / x, |
| ) |
|
|
|
|
| def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor: |
| even = kernel_size % 2 == 0 |
| half_size = kernel_size // 2 |
| delta_f = 4 * half_width |
| amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 |
| if amplitude > 50.0: |
| beta = 0.1102 * (amplitude - 8.7) |
| elif amplitude >= 21.0: |
| beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0) |
| else: |
| beta = 0.0 |
| window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) |
| time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size |
| if cutoff == 0: |
| filter_ = torch.zeros_like(time) |
| else: |
| filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) |
| filter_ /= filter_.sum() |
| return filter_.view(1, 1, kernel_size) |
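|
| # The returned filter is normalized to unit sum, i.e. unity DC gain (for any |
| # cutoff > 0), e.g.: |
| # |
| #   f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12) |
| #   assert f.shape == (1, 1, 12) and abs(float(f.sum()) - 1.0) < 1e-6 |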
|
|
|
|
| class LowPassFilter1d(nn.Module): |
| def __init__( |
| self, |
| cutoff: float = 0.5, |
| half_width: float = 0.6, |
| stride: int = 1, |
| padding: bool = True, |
| padding_mode: str = "replicate", |
| kernel_size: int = 12, |
| ) -> None: |
| super().__init__() |
| if cutoff < 0.0: |
| raise ValueError("Minimum cutoff must not be negative.") |
| if cutoff > 0.5: |
| raise ValueError("A cutoff above 0.5 does not make sense.") |
| self.kernel_size = kernel_size |
| self.even = kernel_size % 2 == 0 |
| self.pad_left = kernel_size // 2 - int(self.even) |
| self.pad_right = kernel_size // 2 |
| self.stride = stride |
| self.padding = padding |
| self.padding_mode = padding_mode |
| self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| _, n_channels, _ = x.shape |
| if self.padding: |
| x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) |
| return F.conv1d(x, self.filter.expand(n_channels, -1, -1), stride=self.stride, groups=n_channels) |
|
|
|
|
| class UpSample1d(nn.Module): |
| def __init__( |
| self, |
| ratio: int = 2, |
| kernel_size: int | None = None, |
| persistent: bool = True, |
| window_type: str = "kaiser", |
| ) -> None: |
| super().__init__() |
| self.ratio = ratio |
| self.stride = ratio |
|
|
| if window_type == "hann": |
| # Hann-windowed sinc interpolation kernel. |
| rolloff = 0.99 |
| lowpass_filter_width = 6 |
| width = math.ceil(lowpass_filter_width / rolloff) |
| self.kernel_size = 2 * width * ratio + 1 |
| self.pad = width |
| self.pad_left = 2 * width * ratio |
| self.pad_right = self.kernel_size - ratio |
| time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff |
| time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width) |
| window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2 |
| sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1) |
| else: |
| # Kaiser-windowed sinc kernel (default). |
| self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size |
| self.pad = self.kernel_size // ratio - 1 |
| self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 |
| self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 |
| sinc_filter = kaiser_sinc_filter1d( |
| cutoff=0.5 / ratio, |
| half_width=0.6 / ratio, |
| kernel_size=self.kernel_size, |
| ) |
|
|
| self.register_buffer("filter", sinc_filter, persistent=persistent) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| _, n_channels, _ = x.shape |
| x = F.pad(x, (self.pad, self.pad), mode="replicate") |
| filt = self.filter.to(dtype=x.dtype, device=x.device).expand(n_channels, -1, -1) |
| x = self.ratio * F.conv_transpose1d(x, filt, stride=self.stride, groups=n_channels) |
| return x[..., self.pad_left : -self.pad_right] |
|
|
|
|
| class DownSample1d(nn.Module): |
| def __init__(self, ratio: int = 2, kernel_size: int | None = None) -> None: |
| super().__init__() |
| self.ratio = ratio |
| self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size |
| self.lowpass = LowPassFilter1d( |
| cutoff=0.5 / ratio, |
| half_width=0.6 / ratio, |
| stride=ratio, |
| kernel_size=self.kernel_size, |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return self.lowpass(x) |
|
|
|
|
| class Activation1d(nn.Module): |
| def __init__( |
| self, |
| activation: nn.Module, |
| up_ratio: int = 2, |
| down_ratio: int = 2, |
| up_kernel_size: int = 12, |
| down_kernel_size: int = 12, |
| ) -> None: |
| super().__init__() |
| self.act = activation |
| self.upsample = UpSample1d(up_ratio, up_kernel_size) |
| self.downsample = DownSample1d(down_ratio, down_kernel_size) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = self.upsample(x) |
| x = self.act(x) |
| return self.downsample(x) |
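|
| # Anti-aliased activation sketch (illustrative): upsample 2x, apply the |
| # nonlinearity at the higher rate, then low-pass and downsample 2x. Sequence |
| # length is preserved while harmonics pushed above Nyquist are attenuated: |
| # |
| #   act = Activation1d(Snake(8)) |
| #   x = torch.randn(1, 8, 100) |
| #   assert act(x).shape == x.shape |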
|
|
|
|
| class Snake(nn.Module): |
| def __init__( |
| self, |
| in_features: int, |
| alpha: float = 1.0, |
| alpha_trainable: bool = True, |
| alpha_logscale: bool = True, |
| ) -> None: |
| super().__init__() |
| self.alpha_logscale = alpha_logscale |
| self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) |
| self.alpha.requires_grad = alpha_trainable |
| self.eps = 1e-9 |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| alpha = self.alpha.unsqueeze(0).unsqueeze(-1) |
| if self.alpha_logscale: |
| alpha = torch.exp(alpha) |
| return x + (1.0 / (alpha + self.eps)) * torch.sin(x * alpha).pow(2) |
|
|
|
|
| class SnakeBeta(nn.Module): |
| def __init__( |
| self, |
| in_features: int, |
| alpha: float = 1.0, |
| alpha_trainable: bool = True, |
| alpha_logscale: bool = True, |
| ) -> None: |
| super().__init__() |
| self.alpha_logscale = alpha_logscale |
| self.alpha = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) |
| self.alpha.requires_grad = alpha_trainable |
| self.beta = nn.Parameter(torch.zeros(in_features) if alpha_logscale else torch.ones(in_features) * alpha) |
| self.beta.requires_grad = alpha_trainable |
| self.eps = 1e-9 |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| alpha = self.alpha.unsqueeze(0).unsqueeze(-1) |
| beta = self.beta.unsqueeze(0).unsqueeze(-1) |
| if self.alpha_logscale: |
| alpha = torch.exp(alpha) |
| beta = torch.exp(beta) |
| return x + (1.0 / (beta + self.eps)) * torch.sin(x * alpha).pow(2) |
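|
| # Snake computes f(x) = x + (1/alpha) * sin^2(alpha * x) per channel; SnakeBeta |
| # additionally learns the magnitude 1/beta independently of the frequency |
| # alpha. Both are stored in log scale by default, so zero-initialized |
| # parameters start as alpha = beta = 1. |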
|
|
|
|
| class AMPBlock1(nn.Module): |
| def __init__( |
| self, |
| channels: int, |
| kernel_size: int = 3, |
| dilation: tuple[int, int, int] = (1, 3, 5), |
| activation: str = "snake", |
| ) -> None: |
| super().__init__() |
| act_cls = SnakeBeta if activation == "snakebeta" else Snake |
| self.convs1 = nn.ModuleList( |
| [ |
| nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[0], |
| padding=get_padding(kernel_size, dilation[0]), |
| ), |
| nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[1], |
| padding=get_padding(kernel_size, dilation[1]), |
| ), |
| nn.Conv1d( |
| channels, |
| channels, |
| kernel_size, |
| 1, |
| dilation=dilation[2], |
| padding=get_padding(kernel_size, dilation[2]), |
| ), |
| ] |
| ) |
|
|
| self.convs2 = nn.ModuleList( |
| [ |
| nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), |
| nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), |
| nn.Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)), |
| ] |
| ) |
|
|
| self.acts1 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]) |
| self.acts2 = nn.ModuleList([Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2, strict=True): |
| xt = a1(x) |
| xt = c1(xt) |
| xt = a2(xt) |
| xt = c2(xt) |
| x = x + xt |
| return x |
|
|
|
|
| class LTX2Vocoder(torch.nn.Module): |
| """ |
| LTX2Vocoder model for synthesizing audio from Mel spectrograms. |
| Args: |
| resblock_kernel_sizes: List of kernel sizes for the residual blocks. |
| This value is read from the checkpoint at `config.vocoder.resblock_kernel_sizes`. |
| upsample_rates: List of upsampling rates. |
| This value is read from the checkpoint at `config.vocoder.upsample_rates`. |
| upsample_kernel_sizes: List of kernel sizes for the upsampling layers. |
| This value is read from the checkpoint at `config.vocoder.upsample_kernel_sizes`. |
| resblock_dilation_sizes: List of dilation sizes for the residual blocks. |
| This value is read from the checkpoint at `config.vocoder.resblock_dilation_sizes`. |
| upsample_initial_channel: Initial number of channels for the upsampling layers. |
| This value is read from the checkpoint at `config.vocoder.upsample_initial_channel`. |
| resblock: Type of residual block to use ("1", "2", or "AMP1"). |
| This value is read from the checkpoint at `config.vocoder.resblock`. |
| output_sampling_rate: Waveform sample rate. |
| This value is read from the checkpoint at `config.vocoder.output_sampling_rate`. |
| activation: Activation type for BigVGAN v2 ("snake" or "snakebeta"). Only used when resblock="AMP1". |
| use_tanh_at_final: Apply tanh at the output (when apply_final_activation=True). |
| apply_final_activation: Whether to apply the final tanh/clamp activation. |
| use_bias_at_final: Whether to use bias in the final conv layer. |
| """ |
|
|
| def __init__( |
| self, |
| resblock_kernel_sizes: List[int] | None = None, |
| upsample_rates: List[int] | None = None, |
| upsample_kernel_sizes: List[int] | None = None, |
| resblock_dilation_sizes: List[List[int]] | None = None, |
| upsample_initial_channel: int = 1024, |
| resblock: str = "1", |
| output_sampling_rate: int = 24000, |
| activation: str = "snake", |
| use_tanh_at_final: bool = True, |
| apply_final_activation: bool = True, |
| use_bias_at_final: bool = True, |
| ) -> None: |
| super().__init__() |
|
|
| # Materialize the documented defaults (avoids mutable default arguments). |
| if resblock_kernel_sizes is None: |
| resblock_kernel_sizes = [3, 7, 11] |
| if upsample_rates is None: |
| upsample_rates = [6, 5, 2, 2, 2] |
| if upsample_kernel_sizes is None: |
| upsample_kernel_sizes = [16, 15, 8, 4, 4] |
| if resblock_dilation_sizes is None: |
| resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] |
|
|
| self.output_sampling_rate = output_sampling_rate |
| self.num_kernels = len(resblock_kernel_sizes) |
| self.num_upsamples = len(upsample_rates) |
| self.use_tanh_at_final = use_tanh_at_final |
| self.apply_final_activation = apply_final_activation |
| self.is_amp = resblock == "AMP1" |
|
|
| # conv_pre expects 128 input channels: 2 audio channels x 64 mel bins, |
| # flattened by forward() for stereo input. |
| self.conv_pre = nn.Conv1d( |
| in_channels=128, |
| out_channels=upsample_initial_channel, |
| kernel_size=7, |
| stride=1, |
| padding=3, |
| ) |
| resblock_cls = {"1": ResBlock1, "2": ResBlock2, "AMP1": AMPBlock1}[resblock] |
|
|
| self.ups = nn.ModuleList( |
| nn.ConvTranspose1d( |
| upsample_initial_channel // (2**i), |
| upsample_initial_channel // (2 ** (i + 1)), |
| kernel_size, |
| stride, |
| padding=(kernel_size - stride) // 2, |
| ) |
| for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes, strict=True)) |
| ) |
|
|
| final_channels = upsample_initial_channel // (2 ** len(upsample_rates)) |
| self.resblocks = nn.ModuleList() |
|
|
| for i in range(len(upsample_rates)): |
| ch = upsample_initial_channel // (2 ** (i + 1)) |
| for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes, strict=True): |
| if self.is_amp: |
| self.resblocks.append(resblock_cls(ch, kernel_size, dilations, activation=activation)) |
| else: |
| self.resblocks.append(resblock_cls(ch, kernel_size, dilations)) |
|
|
| if self.is_amp: |
| self.act_post: nn.Module = Activation1d(SnakeBeta(final_channels)) |
| else: |
| self.act_post = nn.LeakyReLU() |
|
|
| # Project features to a 2-channel (stereo) waveform. |
| self.conv_post = nn.Conv1d( |
| in_channels=final_channels, |
| out_channels=2, |
| kernel_size=7, |
| stride=1, |
| padding=3, |
| bias=use_bias_at_final, |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """ |
| Forward pass of the vocoder. |
| Args: |
| x: Input Mel spectrogram tensor. Can be either: |
| - 3D: (batch_size, time, mel_bins) for mono |
| - 4D: (batch_size, 2, time, mel_bins) for stereo |
| Returns: |
| Audio waveform tensor of shape (batch_size, out_channels, audio_length) |
| """ |
| # Move mel bins before time: (..., time, mel_bins) -> (..., mel_bins, time). |
| x = x.transpose(-2, -1) |
|
|
| if x.dim() == 4: |
| assert x.shape[1] == 2, "Input must have 2 channels for stereo" |
| x = einops.rearrange(x, "b s c t -> b (s c) t") |
|
|
| x = self.conv_pre(x) |
|
|
| for i in range(self.num_upsamples): |
| if not self.is_amp: |
| x = F.leaky_relu(x, LRELU_SLOPE) |
| x = self.ups[i](x) |
| start = i * self.num_kernels |
| end = start + self.num_kernels |
|
|
| # Multi-receptive-field fusion: run this stage's resblocks in parallel |
| # and average their outputs. |
| block_outputs = torch.stack( |
| [self.resblocks[idx](x) for idx in range(start, end)], |
| dim=0, |
| ) |
| x = block_outputs.mean(dim=0) |
|
|
| x = self.act_post(x) |
| x = self.conv_post(x) |
|
|
| if self.apply_final_activation: |
| x = torch.tanh(x) if self.use_tanh_at_final else torch.clamp(x, -1, 1) |
|
|
| return x |
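|
| # Vocoder shape sketch (illustrative, default config): the upsample rates |
| # (6, 5, 2, 2, 2) multiply to 240, so T mel frames become 240 * T samples: |
| # |
| #   vocoder = LTX2Vocoder() |
| #   mel = torch.randn(1, 2, 100, 64)  # stereo: (B, 2, time, mel_bins) |
| #   wav = vocoder(mel)                # -> (1, 2, 24000), i.e. 1 s at 24 kHz |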
|
|
|
|
| class _STFTFn(nn.Module): |
| """Implements STFT as a convolution with precomputed DFT x Hann-window bases. |
| The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal |
| Hann window are stored as buffers and loaded from the checkpoint. Using the exact |
| bfloat16 bases from training ensures the mel values fed to the BWE generator are |
| bit-identical to what it was trained on. |
| """ |
|
|
| def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None: |
| super().__init__() |
| self.hop_length = hop_length |
| self.win_length = win_length |
| n_freqs = filter_length // 2 + 1 |
| self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) |
| self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) |
|
|
| def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: |
| """Compute magnitude and phase spectrogram from a batch of waveforms. |
| Applies causal (left-only) padding of win_length - hop_length samples so that |
| each output frame depends only on past and present input — no lookahead. |
| Args: |
| y: Waveform tensor of shape (B, T). |
| Returns: |
| magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). |
| phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). |
| """ |
| if y.dim() == 2: |
| y = y.unsqueeze(1) |
| left_pad = max(0, self.win_length - self.hop_length) |
| y = F.pad(y, (left_pad, 0)) |
| spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) |
| n_freqs = spec.shape[1] // 2 |
| real, imag = spec[:, :n_freqs], spec[:, n_freqs:] |
| magnitude = torch.sqrt(real**2 + imag**2) |
| phase = torch.atan2(imag.float(), real.float()).to(real.dtype) |
| return magnitude, phase |
|
|
|
|
| class MelSTFT(nn.Module): |
| """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. |
| Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input |
| waveform and projecting the linear magnitude spectrum onto the mel filterbank. |
| The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint |
| (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). |
| """ |
|
|
| def __init__( |
| self, |
| filter_length: int, |
| hop_length: int, |
| win_length: int, |
| n_mel_channels: int, |
| ) -> None: |
| super().__init__() |
| self.stft_fn = _STFTFn(filter_length, hop_length, win_length) |
|
|
| |
| |
| n_freqs = filter_length // 2 + 1 |
| self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) |
|
|
| def mel_spectrogram(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| """Compute log-mel spectrogram and auxiliary spectral quantities. |
| Args: |
| y: Waveform tensor of shape (B, T). |
| Returns: |
| log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). |
| magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). |
| phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). |
| energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). |
| """ |
| magnitude, phase = self.stft_fn(y) |
| energy = torch.norm(magnitude, dim=1) |
| mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) |
| log_mel = torch.log(torch.clamp(mel, min=1e-5)) |
| return log_mel, magnitude, phase, energy |
|
|
|
|
| class LTX2VocoderWithBWE(nn.Module): |
| """LTX2Vocoder with bandwidth extension (BWE) upsampling. |
| Chains a mel-to-wav vocoder with a BWE module that upsamples the output |
| to a higher sample rate. The BWE computes a mel spectrogram from the |
| vocoder output, runs it through a second generator to predict a residual, |
| and adds it to a sinc-resampled skip connection. |
| """ |
|
|
| def __init__( |
| self, |
| input_sampling_rate: int = 16000, |
| output_sampling_rate: int = 48000, |
| hop_length: int = 80, |
| ) -> None: |
| super().__init__() |
| self.vocoder = LTX2Vocoder( |
| resblock_kernel_sizes=[3, 7, 11], |
| upsample_rates=[5, 2, 2, 2, 2, 2], |
| upsample_kernel_sizes=[11, 4, 4, 4, 4, 4], |
| resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], |
| upsample_initial_channel=1536, |
| resblock="AMP1", |
| activation="snakebeta", |
| use_tanh_at_final=False, |
| apply_final_activation=True, |
| use_bias_at_final=False, |
| output_sampling_rate=input_sampling_rate, |
| ) |
| self.bwe_generator = LTX2Vocoder( |
| resblock_kernel_sizes=[3, 7, 11], |
| upsample_rates=[6, 5, 2, 2, 2], |
| upsample_kernel_sizes=[12, 11, 4, 4, 4], |
| resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], |
| upsample_initial_channel=512, |
| resblock="AMP1", |
| activation="snakebeta", |
| use_tanh_at_final=False, |
| apply_final_activation=False, |
| use_bias_at_final=False, |
| output_sampling_rate=output_sampling_rate, |
| ) |
| # Causal mel analysis of the vocoder output; buffers are loaded from the checkpoint. |
| self.mel_stft = MelSTFT( |
| filter_length=512, |
| hop_length=hop_length, |
| win_length=512, |
| n_mel_channels=64, |
| ) |
| self.input_sampling_rate = input_sampling_rate |
| self.output_sampling_rate = output_sampling_rate |
| self.hop_length = hop_length |
| # Build the sinc resampler on the CPU; its filter buffer is non-persistent, |
| # so it is recomputed here rather than loaded from the checkpoint. |
| with torch.device("cpu"): |
| self.resampler = UpSample1d( |
| ratio=output_sampling_rate // input_sampling_rate, persistent=False, window_type="hann" |
| ) |
|
|
| @property |
| def conv_pre(self) -> nn.Conv1d: |
| return self.vocoder.conv_pre |
|
|
| @property |
| def conv_post(self) -> nn.Conv1d: |
| return self.vocoder.conv_post |
|
|
| def _compute_mel(self, audio: torch.Tensor) -> torch.Tensor: |
| """Compute log-mel spectrogram from waveform using causal STFT bases. |
| Args: |
| audio: Waveform tensor of shape (B, C, T). |
| Returns: |
| mel: Log-mel spectrogram of shape (B, C, n_mels, T_frames). |
| """ |
| batch, n_channels, _ = audio.shape |
| flat = audio.reshape(batch * n_channels, -1) |
| mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) |
| return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2]) |
|
|
| def forward(self, mel_spec: torch.Tensor) -> torch.Tensor: |
| """Run the full vocoder + BWE forward pass. |
| Args: |
| mel_spec: Mel spectrogram of shape (B, 2, T, mel_bins) for stereo |
| or (B, T, mel_bins) for mono. Same format as LTX2Vocoder.forward. |
| Returns: |
| Waveform tensor of shape (B, out_channels, T_out) clipped to [-1, 1]. |
| """ |
| x = self.vocoder(mel_spec) |
| _, _, length_low_rate = x.shape |
| output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate |
|
|
| # Pad the low-rate waveform so its length divides evenly into the mel hop. |
| remainder = length_low_rate % self.hop_length |
| if remainder != 0: |
| x = F.pad(x, (0, self.hop_length - remainder)) |
|
|
| # Log-mel analysis of the (possibly padded) low-rate waveform. |
| mel = self._compute_mel(x) |
|
|
| # (B, C, n_mels, T) -> (B, C, T, n_mels), the layout LTX2Vocoder.forward expects. |
| mel_for_bwe = mel.transpose(2, 3) |
| residual = self.bwe_generator(mel_for_bwe) |
| skip = self.resampler(x) |
| assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" |
|
|
| return torch.clamp(residual + skip, -1, 1)[..., :output_length] |
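|
| # End-to-end BWE sketch (illustrative; real use loads checkpoint weights into |
| # the STFT/mel buffers and both generators). The internal vocoder renders |
| # 16 kHz audio (160 samples per mel frame); the BWE stage resamples 3x to |
| # 48 kHz and adds a predicted residual: |
| # |
| #   model = LTX2VocoderWithBWE() |
| #   mel = torch.randn(1, 2, 100, 64) |
| #   wav = model(mel)  # -> (1, 2, 48000): 100 * 160 * 3 samples, clamped to [-1, 1] |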
|
|