# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac

import torch
import torch.nn.functional as F
import torch.nn as nn

# from models import register  # repo-local registry import; unused in this file

class PositionalEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP."""

    def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True):
        super().__init__()
        self.num_channels = pe_dim
        self.max_positions = max_positions
        self.endpoint = endpoint
        self.f_1 = nn.Linear(pe_dim, out_dim)
        self.f_2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32, device=x.device)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = torch.outer(x, freqs.to(x.dtype))  # outer product; replaces the deprecated Tensor.ger
        x = torch.cat([x.cos(), x.sin()], dim=1)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)

class AudioEmbedding(nn.Module):
    """1D convolution for audio input embedding"""

    def __init__(self, in_channels, out_channels=320, kernel_size=3) -> None:
        super().__init__()
        self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)

class AudioUnembedding(nn.Module):
    """1D convolution for audio output"""

    def __init__(self, in_channels=320, out_channels=1, kernel_size=3) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x) -> torch.Tensor:
        return self.f(F.silu(self.gn(x)))

class AudioConvResblock(nn.Module):
    """1D residual block for audio with timestep scale/shift conditioning"""

    def __init__(self, in_features, out_features, t_dim, kernel_size=3) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv1d(in_features, out_features, kernel_size=kernel_size, padding=kernel_size // 2)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv1d(out_features, out_features, kernel_size=kernel_size, padding=kernel_size // 2)
        skip_conv = in_features != out_features
        self.f_s = (
            nn.Conv1d(in_features, out_features, kernel_size=1, padding=0)
            if skip_conv
            else nn.Identity()
        )

    def forward(self, x, t):
        x_skip = x
        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(dim=2) + 1  # scale: [batch, channels, 1]
        t_2 = t_2.unsqueeze(dim=2)      # shift: [batch, channels, 1]
        gn_1 = F.silu(self.gn_1(x))
        f_1 = self.f_1(gn_1)
        gn_2 = self.gn_2(f_1)
        return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))

class AudioDownsample(nn.Module):
    """1D downsampling for audio"""

    def __init__(self, in_channels, t_dim, downsample_factor=2) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)
        self.downsample_factor = downsample_factor
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x
        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2) + 1
        t_2 = t_2.unsqueeze(2)
        gn_1 = F.silu(self.gn_1(x))
        # 1D average pooling
        avg_pool1d = F.avg_pool1d(gn_1, kernel_size=self.downsample_factor)
        f_1 = self.f_1(avg_pool1d)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
        return f_2 + F.avg_pool1d(x_skip, kernel_size=self.downsample_factor)

class AudioUpsample(nn.Module):
    """1D upsampling for audio"""

    def __init__(self, in_channels, t_dim, upsample_factor=2) -> None:
        super().__init__()
        self.f_t = nn.Linear(t_dim, in_channels * 2)
        self.upsample_factor = upsample_factor
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x
        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2) + 1
        t_2 = t_2.unsqueeze(2)
        gn_1 = F.silu(self.gn_1(x))
        # 1D interpolation upsampling
        upsample = F.interpolate(gn_1, scale_factor=self.upsample_factor, mode='linear')
        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
        return f_2 + F.interpolate(x_skip, scale_factor=self.upsample_factor, mode='linear')

class AudioDiffusionUNet(nn.Module):
    """
    1D UNet for audio diffusion with dynamic latent conditioning.

    Handles:
        - x: [batch, 1, samples] - audio waveform (dynamic length)
        - z_dec: [batch, 64, n_frames] - latent conditioning (dynamic length)

    The network downsamples by a factor of 2 three times, so the input length
    must be divisible by 8 for the skip connections to line up.
    """

    def __init__(
        self,
        in_channels=1,           # audio channels (mono=1, stereo=2)
        z_dec_channels=64,       # latent conditioning channels
        c0=128, c1=256, c2=512,  # channel progression (smaller than the image version)
        pe_dim=320,
        t_dim=1280,
        kernel_size=3,
    ) -> None:
        super().__init__()
        # Store for dynamic conditioning
        self.z_dec_channels = z_dec_channels
        # Audio input embedding
        self.embed_audio = AudioEmbedding(
            in_channels=in_channels,
            out_channels=c0,
            kernel_size=kernel_size,
        )
        # Time embedding
        self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim)
        # Latent conditioning projection
        if z_dec_channels is not None:
            self.z_dec_proj = nn.Conv1d(z_dec_channels, c0, kernel_size=1)
        # Downsampling path
        down_0 = nn.ModuleList([
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioConvResblock(c0, c0, t_dim, kernel_size),
            AudioDownsample(c0, t_dim),
        ])
        down_1 = nn.ModuleList([
            AudioConvResblock(c0, c1, t_dim, kernel_size),
            AudioConvResblock(c1, c1, t_dim, kernel_size),
            AudioConvResblock(c1, c1, t_dim, kernel_size),
            AudioDownsample(c1, t_dim),
        ])
        down_2 = nn.ModuleList([
            AudioConvResblock(c1, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioDownsample(c2, t_dim),
        ])
        down_3 = nn.ModuleList([
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
        ])
        self.down = nn.ModuleList([down_0, down_1, down_2, down_3])
        # Middle layers
        self.mid = nn.ModuleList([
            AudioConvResblock(c2, c2, t_dim, kernel_size),
            AudioConvResblock(c2, c2, t_dim, kernel_size),
        ])
        # Upsampling path (input channels account for concatenated skips)
        up_3 = nn.ModuleList([
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioUpsample(c2, t_dim),
        ])
        up_2 = nn.ModuleList([
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 * 2, c2, t_dim, kernel_size),
            AudioConvResblock(c2 + c1, c2, t_dim, kernel_size),
            AudioUpsample(c2, t_dim),
        ])
        up_1 = nn.ModuleList([
            AudioConvResblock(c2 + c1, c1, t_dim, kernel_size),
            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
            AudioConvResblock(c1 * 2, c1, t_dim, kernel_size),
            AudioConvResblock(c0 + c1, c1, t_dim, kernel_size),
            AudioUpsample(c1, t_dim),
        ])
        up_0 = nn.ModuleList([
            AudioConvResblock(c0 + c1, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
            AudioConvResblock(c0 * 2, c0, t_dim, kernel_size),
        ])
        self.up = nn.ModuleList([up_0, up_1, up_2, up_3])
        # Output layer
        self.output = AudioUnembedding(in_channels=c0, out_channels=in_channels)

    def get_last_layer_weight(self):
        return self.output.f.weight

    def condition_with_latents(self, x, z_dec):
        """
        Add latent conditioning to audio features.

        Args:
            x: [batch, c0, audio_samples] - audio features
            z_dec: [batch, 64, n_frames] - latent conditioning
        Returns:
            x: [batch, c0, audio_samples] - conditioned features
        """
        if z_dec is None:
            return x
        # Project latents to the same channel dimension as the audio features
        z_proj = self.z_dec_proj(z_dec)  # [batch, c0, n_frames]
        # Interpolate latents along time to match the audio length
        if z_proj.shape[-1] != x.shape[-1]:
            z_proj = F.interpolate(
                z_proj,
                size=x.shape[-1],
                mode='linear',  # 'nearest' is a cheaper alternative
            )
        # Add latent conditioning to audio features; summing (rather than
        # concatenating) keeps the channel count at c0, which is what the
        # first downsampling block expects
        return x + z_proj

    def forward(self, x, t=None, z_dec=None) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: [batch, 1, samples] - audio waveform (length divisible by 8)
            t: [batch] - diffusion timesteps
            z_dec: [batch, 64, n_frames] - latent conditioning (any length)
        """
        # Embed audio input
        x = self.embed_audio(x)  # [batch, c0, samples]
        # Add latent conditioning
        if z_dec is not None:
            x = self.condition_with_latents(x, z_dec)
        # Embed timestep
        if t is None:
            t = torch.zeros(x.shape[0], device=x.device)
        t = self.embed_time(t)  # [batch, t_dim]
        # Downsampling with skip connections
        skips = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                skips.append(x)
        # Middle layers
        for mid in self.mid:
            x = mid(x, t)
        # Upsampling; each resblock consumes one skip connection
        for up in self.up[::-1]:
            for block in up:
                if isinstance(block, AudioConvResblock):
                    x = torch.cat([x, skips.pop()], dim=1)
                x = block(x, t)
        # Output
        return self.output(x)
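
# Minimal smoke test (not part of the original gist): a sketch that checks a
# forward pass preserves the waveform shape and that a z_dec with a mismatched
# frame count is interpolated to the audio length. The shapes below are
# illustrative assumptions; 16000 samples is divisible by 8, as the three
# downsampling stages require.
if __name__ == "__main__":
    model = AudioDiffusionUNet()
    x = torch.randn(2, 1, 16000)     # [batch, channels, samples]
    z_dec = torch.randn(2, 64, 100)  # [batch, z_dec_channels, n_frames]
    t = torch.rand(2)                # one diffusion timestep per batch element
    with torch.no_grad():
        y = model(x, t=t, z_dec=z_dec)
    assert y.shape == x.shape, y.shape
    print("output shape:", tuple(y.shape))  # -> (2, 1, 16000)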