# learnable-speech / flowae / models / networks / consistency_audio_decoder_unet.py
# primepake: "add training flowvae" (commit 4f877a2)
# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac
import torch
import torch.nn.functional as F
import torch.nn as nn
from models import register
class PositionalEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP.

    Args:
        pe_dim: Width of the sinusoidal feature: ``pe_dim // 2`` cosines followed
            by ``pe_dim // 2`` sines. Should be even, because ``f_1`` expects
            exactly ``pe_dim`` inputs.
        out_dim: Width of the returned embedding.
        max_positions: Sets the lowest frequency, ``1 / max_positions``.
        endpoint: If True the frequency exponents span ``[0, 1]`` inclusive,
            otherwise ``[0, 1)``.
    """

    def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True):
        super().__init__()
        self.num_channels = pe_dim
        self.max_positions = max_positions
        self.endpoint = endpoint
        self.f_1 = nn.Linear(pe_dim, out_dim)
        self.f_2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """Embed a tensor of timesteps.

        Args:
            x: Timesteps of any shape; flattened to one row per element.
               Integer tensors are converted to float32 first.
        Returns:
            Tensor of shape ``[x.numel(), out_dim]``.
        """
        # Without this cast, the fractional frequencies below would be
        # truncated to 0 by ``freqs.to(x.dtype)``, collapsing the embedding.
        if not x.is_floating_point():
            x = x.to(torch.float32)
        x = x.reshape(-1)
        half = self.num_channels // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=x.device)
        # max(..., 1) avoids 0/0 (NaN) when half == 1 and endpoint is True.
        freqs = freqs / max(half - (1 if self.endpoint else 0), 1)
        freqs = (1 / self.max_positions) ** freqs
        x = torch.outer(x, freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)
class AudioEmbedding(nn.Module):
    """Lift an audio signal into feature space with a single 1-D convolution.

    Both sides are padded by ``kernel_size // 2``. The time length is therefore
    preserved only for odd ``kernel_size``; an even size adds one sample.
    """

    def __init__(self, in_channels, out_channels=320, kernel_size=3) -> None:
        super().__init__()
        half_width = kernel_size // 2
        self.f = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=half_width,
        )

    def forward(self, x) -> torch.Tensor:
        """Map ``[batch, in_channels, T]`` to ``[batch, out_channels, T']``."""
        conv = self.f
        return conv(x)
class AudioUnembedding(nn.Module):
    """Project features back to audio channels: GroupNorm -> SiLU -> Conv1d.

    ``in_channels`` must be divisible by 32, the GroupNorm group count.
    """

    def __init__(self, in_channels=320, out_channels=1, kernel_size=3) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, x) -> torch.Tensor:
        """Normalize, activate and convolve ``x`` of shape ``[batch, in_channels, T]``."""
        normed = self.gn(x)
        activated = F.silu(normed)
        return self.f(activated)
class AudioConvResblock(nn.Module):
    """Residual 1-D conv block with timestep-driven scale/shift modulation.

    Main path: GroupNorm -> SiLU -> Conv -> GroupNorm -> ``* (scale + 1) + shift``
    -> SiLU -> Conv. The result is added to the skip path, which is a 1x1 conv
    when the channel count changes and the identity otherwise. Both
    ``in_features`` and ``out_features`` must be divisible by 32 (GroupNorm
    groups).
    """

    def __init__(self, in_features, out_features, t_dim, kernel_size=3) -> None:
        super().__init__()
        pad = kernel_size // 2
        # Submodule creation order is kept fixed so seeded initialization and
        # state_dict layout stay stable.
        self.f_t = nn.Linear(t_dim, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv1d(in_features, out_features, kernel_size=kernel_size, padding=pad)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv1d(out_features, out_features, kernel_size=kernel_size, padding=pad)
        if in_features == out_features:
            self.f_s = nn.Identity()
        else:
            self.f_s = nn.Conv1d(in_features, out_features, kernel_size=1, padding=0)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: Features ``[batch, in_features, T]``.
            t: Timestep embedding, projected by ``f_t`` then split in half along dim 1.
        Returns:
            ``[batch, out_features, T]`` (T preserved for odd ``kernel_size``).
        """
        scale, shift = (
            part.unsqueeze(2) for part in self.f_t(F.silu(t)).chunk(2, dim=1)
        )
        h = self.f_1(F.silu(self.gn_1(x)))
        h = self.gn_2(h) * (scale + 1) + shift
        return self.f_s(x) + self.f_2(F.silu(h))
class AudioDownsample(nn.Module):
    """Timestep-conditioned residual block that shortens time by ``downsample_factor``.

    Average pooling (kernel = stride = ``downsample_factor``) is applied to the
    main path, after the first norm/activation, and to the skip path, so a
    length ``T`` becomes ``T // downsample_factor``. ``in_channels`` must be
    divisible by 32.
    """

    def __init__(self, in_channels, t_dim, downsample_factor=2) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)
        self.downsample_factor = downsample_factor
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        """Downsample ``x`` (``[batch, in_channels, T]``) conditioned on ``t``."""
        k = self.downsample_factor
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale.unsqueeze(2) + 1
        shift = shift.unsqueeze(2)

        h = F.avg_pool1d(F.silu(self.gn_1(x)), kernel_size=k)
        h = self.gn_2(self.f_1(h))
        h = self.f_2(F.silu(shift + scale * h))

        residual = F.avg_pool1d(x, kernel_size=k)
        return h + residual
class AudioUpsample(nn.Module):
    """Timestep-conditioned residual block that lengthens time by ``upsample_factor``.

    Linear interpolation is applied to the main path, after the first
    norm/activation, and to the skip path. ``in_channels`` must be divisible
    by 32.
    """

    def __init__(self, in_channels, t_dim, upsample_factor=2) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)
        self.upsample_factor = upsample_factor
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        """Upsample ``x`` (``[batch, in_channels, T]``) conditioned on ``t``."""
        factor = self.upsample_factor
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        scale = scale.unsqueeze(2) + 1
        shift = shift.unsqueeze(2)

        h = F.interpolate(F.silu(self.gn_1(x)), scale_factor=factor, mode='linear')
        h = self.gn_2(self.f_1(h))
        h = self.f_2(F.silu(shift + scale * h))

        residual = F.interpolate(x, scale_factor=factor, mode='linear')
        return h + residual
@register('audio_diffusion_unet')
class AudioDiffusionUNet(nn.Module):
    """
    1D UNet for audio diffusion with dynamic latent conditioning.

    Handles:
        - x: [batch, in_channels, samples] - audio waveform. Any length is
          accepted; features are zero-padded internally to a multiple of the
          total downsampling factor (8 with the three default AudioDownsample
          stages) and the output is cropped back.
        - z_dec: [batch, z_dec_channels, n_frames] - latent conditioning. It is
          projected to c0 channels, linearly resampled to the audio length and
          added to the embedded audio.
    """

    def __init__(
        self,
        in_channels=1,        # Audio channels (mono=1, stereo=2)
        z_dec_channels=64,    # Latent conditioning channels (None disables conditioning)
        c0=128, c1=256, c2=512,  # Channel progression; each must be divisible by 32
        pe_dim=320,
        t_dim=1280,
        kernel_size=3
    ) -> None:
        super().__init__()
        # Store for dynamic conditioning
        self.z_dec_channels = z_dec_channels

        # Audio input embedding
        self.embed_audio = AudioEmbedding(
            in_channels=in_channels,
            out_channels=c0,
            kernel_size=kernel_size
        )

        # Time embedding
        self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim)

        # Latent conditioning projection (to c0 so it can be added to the audio features)
        if z_dec_channels is not None:
            self.z_dec_proj = nn.Conv1d(z_dec_channels, c0, kernel_size=1)

        # Downsampling path. Every block output is pushed as a skip in forward().
        down_0 = nn.ModuleList([
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioDownsample(c0, t_dim),
        ])
        down_1 = nn.ModuleList([
            AudioConvResblock(c0, c1, t_dim, kernel_size),
            AudioConvResblock(c1, c1, t_dim, kernel_size),
            AudioConvResblock(c1, c1, t_dim, kernel_size),
            AudioDownsample(c1, t_dim),
        ])
        down_2 = nn.ModuleList([
            AudioConvResblock(c1, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioDownsample(c2, t_dim),
        ])
        down_3 = nn.ModuleList([
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
        ])
        self.down = nn.ModuleList([down_0, down_1, down_2, down_3])

        # Middle layers
        self.mid = nn.ModuleList([
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
        ])

        # Upsampling path. Each resblock's input width is (current width + width
        # of the skip it pops), so these counts mirror the down path in reverse.
        up_3 = nn.ModuleList([
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioUpsample(c2, t_dim),
        ])
        up_2 = nn.ModuleList([
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 + c1, c2, t_dim, kernel_size),
            AudioUpsample(c2, t_dim),
        ])
        up_1 = nn.ModuleList([
            AudioConvResblock(c2 + c1, c1, t_dim, kernel_size),
            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
            AudioConvResblock(c0 + c1, c1, t_dim, kernel_size),
            AudioUpsample(c1, t_dim),
        ])
        up_0 = nn.ModuleList([
            AudioConvResblock(c0 + c1, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
        ])
        self.up = nn.ModuleList([up_0, up_1, up_2, up_3])

        # Output layer
        self.output = AudioUnembedding(in_channels=c0, out_channels=in_channels)

    def get_last_layer_weight(self):
        """Return the weight tensor of the final output convolution."""
        return self.output.f.weight

    def _length_multiple(self) -> int:
        """Return the product of all downsampling factors in the encoder path."""
        multiple = 1
        for stage in self.down:
            for block in stage:
                if isinstance(block, AudioDownsample):
                    multiple *= block.downsample_factor
        return multiple

    def condition_with_latents(self, x, z_dec):
        """
        Add latent conditioning to audio features.

        Args:
            x: [batch, c0, audio_samples] - audio features
            z_dec: [batch, z_dec_channels, n_frames] - latent conditioning, or None
        Returns:
            x: [batch, c0, audio_samples] - conditioned features (x unchanged if z_dec is None)
        """
        if z_dec is None:
            return x
        # Project latents to the same channel dimension as the audio features.
        z_proj = self.z_dec_proj(z_dec)  # [batch, c0, n_frames]
        # Resample latents along time to match the audio length.
        if z_proj.shape[-1] != x.shape[-1]:
            z_proj = F.interpolate(z_proj, size=x.shape[-1], mode='linear')
        # Additive conditioning keeps the channel count at c0, which is what the
        # first down block and the outermost skip connection expect.
        return x + z_proj

    def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: [batch, in_channels, samples] - audio waveform (any length)
            t: [batch] diffusion timesteps, or a scalar broadcast to the batch;
               None means all-zero timesteps
            z_dec: [batch, z_dec_channels, n_frames] - latent conditioning (any length), or None
        Returns:
            [batch, in_channels, L], where L is the embedded length (equal to
            ``samples`` for odd ``kernel_size``).
        """
        h = self.embed_audio(x)  # [batch, c0, samples]
        if z_dec is not None:
            h = self.condition_with_latents(h, z_dec)

        # Pad to a multiple of the total downsampling factor so that each
        # upsampled tensor has exactly the length of the skip it is concatenated with.
        length = h.shape[-1]
        pad = (-length) % self._length_multiple()
        if pad:
            h = F.pad(h, (0, pad))

        # Embed timestep
        if t is None:
            t = torch.zeros(h.shape[0], device=h.device, dtype=h.dtype)
        else:
            t = torch.as_tensor(t, device=h.device)
            if t.ndim == 0:
                t = t.expand(h.shape[0])
        t = self.embed_time(t)  # [batch, t_dim]

        # Downsampling with skip connections
        skips = [h]
        for stage in self.down:
            for block in stage:
                h = block(h, t)
                skips.append(h)

        # Middle layers
        for block in self.mid:
            h = block(h, t)

        # Upsampling: each resblock consumes the most recent remaining skip.
        for stage in self.up[::-1]:
            for block in stage:
                if isinstance(block, AudioConvResblock):
                    h = torch.cat([h, skips.pop()], dim=1)
                h = block(h, t)

        # Output, cropped back to the pre-padding length
        return self.output(h)[..., :length]