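"""Latent diffusion models for audio generation: content embeddings condition
a denoising backbone via cross-attention, optionally alongside time-aligned
content expanded to the latent frame rate."""
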
import random
from pathlib import Path
from typing import Any, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import diffusers.schedulers as noise_schedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from safetensors.torch import load_file
from tqdm import tqdm

from models.autoencoder.autoencoder_base import AutoEncoderBase
from models.content_encoder.content_encoder import ContentEncoder
from models.content_adapter import ContentAdapterBase
from models.common import (
    CountParamsBase, LoadPretrainedBase, SaveTrainableParamsBase
)
from utils.torch_utilities import (
    create_alignment_path, create_mask_from_length, loss_with_mask,
    trim_or_pad_length
)


class DiffusionMixin:
    def __init__(
        self,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2
    ) -> None:
        self.noise_scheduler_name = noise_scheduler_name
        self.snr_gamma = snr_gamma
        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio
        self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
            self.noise_scheduler_name, subfolder="scheduler"
        )

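    # Min-SNR weighting (Hang et al., https://arxiv.org/abs/2303.09556) clips
    # the per-timestep loss weight at `snr_gamma` so low-noise timesteps do
    # not dominate training. `compute_snr` recovers SNR(t) = (alpha / sigma)^2
    # from the scheduler's cumulative alphas.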
    def compute_snr(self, timesteps: torch.Tensor) -> torch.Tensor:
        """
        Computes SNR as per
        https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5

        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(
            device=timesteps.device
        )[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
            device=timesteps.device
        )[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[
                ..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        snr = (alpha / sigma)**2
        return snr

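    # Training draws a uniformly random timestep per sample; evaluation uses a
    # fixed mid-range timestep so validation losses are comparable across runs.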
    def get_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        training: bool = True
    ) -> torch.Tensor:
        if training:
            timesteps = torch.randint(
                0,
                self.noise_scheduler.config.num_train_timesteps,
                (batch_size, ),
                device=device
            )
        else:
            timesteps = (self.noise_scheduler.config.num_train_timesteps //
                         2) * torch.ones(
                             (batch_size, ), dtype=torch.int64, device=device
                         )
        timesteps = timesteps.long()
        return timesteps

    def get_target(
        self, latent: torch.Tensor, noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the target for loss computation depending on the prediction type.
        """
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(
                latent, noise, timesteps
            )
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )
        return target

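    # With min-SNR weighting, each sample's MSE is scaled by
    # min(SNR(t), snr_gamma) / SNR(t): timesteps whose SNR is below the clip
    # keep weight 1, while easy high-SNR timesteps are down-weighted.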
    def loss_with_snr(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: torch.Tensor,
        loss_reduce: bool = True,
    ) -> torch.Tensor:
        if self.snr_gamma is None:
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=loss_reduce)
        else:
            snr = self.compute_snr(timesteps)
            mse_loss_weights = torch.stack(
                [snr, self.snr_gamma * torch.ones_like(timesteps)],
                dim=1,
            ).min(dim=1)[0]
            mse_loss_weights = mse_loss_weights / snr
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
            if loss_reduce:
                loss = loss.mean()
        return loss

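    # CFG rescaling: shrink the guided prediction toward the conditional
    # branch's standard deviation to counteract the over-exposure artifacts
    # that large guidance scales can cause.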
    def rescale_cfg(
        self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
        guidance_rescale: float
    ) -> torch.Tensor:
        """
        Rescale `pred_cfg` according to `guidance_rescale`, based on the
        findings of [Common Diffusion Noise Schedules and Sample Steps are
        Flawed](https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
        """
        std_cond = pred_cond.std(
            dim=list(range(1, pred_cond.ndim)), keepdim=True
        )
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)

        pred_rescaled = pred_cfg * (std_cond / std_cfg)
        pred_cfg = guidance_rescale * pred_rescaled + (
            1 - guidance_rescale
        ) * pred_cfg
        return pred_cfg


class CrossAttentionAudioDiffusion(
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    DiffusionMixin
):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        duration_offset: float = 1.0,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        nn.Module.__init__(self)
        DiffusionMixin.__init__(
            self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
        )

        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        self.content_encoder = content_encoder
        self.content_encoder.audio_encoder.model = self.autoencoder
        self.content_adapter = content_adapter
        self.backbone = backbone
        self.duration_offset = duration_offset
        self.dummy_param = nn.Parameter(torch.empty(0))

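    # Training: encode audio into latents, encode and adapt the conditioning
    # content, then denoise a noised latent and score it against the
    # scheduler's target (epsilon or v-prediction).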
    def forward(
        self, content: list[Any], task: list[str], waveform: torch.Tensor,
        waveform_lengths: torch.Tensor, instruction: torch.Tensor,
        instruction_lengths: Sequence[int], **kwargs
    ):
        device = self.dummy_param.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        content, content_mask = content_output["content"], content_output[
            "content_mask"]
        instruction_mask = create_mask_from_length(instruction_lengths)
        content, content_mask, global_duration_pred, _ = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)

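        # Global duration is supervised in log space: latent_mask.sum(1) over
        # the latent token rate gives the audio duration in seconds, and the
        # offset keeps the logarithm finite for very short clips.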
        global_duration_target = torch.log(
            latent_mask.sum(1) / self.autoencoder.latent_token_rate +
            self.duration_offset
        )
        global_duration_loss = F.mse_loss(
            global_duration_target, global_duration_pred
        )

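        # Classifier-free guidance training: zero out each sample's content
        # with probability `cfg_drop_ratio` so the model also learns an
        # unconditional prediction.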
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        batch_size = latent.shape[0]
        timesteps = self.get_timesteps(batch_size, device, self.training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )

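        # Move the time axis to dimension 1 so the latent mask applies along
        # time when computing the masked loss.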
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)

        return {
            "diff_loss": diff_loss,
            "global_duration_loss": global_duration_loss,
        }

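    # Inference: the latent length is predicted by the global duration head,
    # then a random latent of that length is iteratively denoised with the
    # provided scheduler.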
    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 20,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        content, content_mask = content_output["content"], content_output[
            "content_mask"]

        instruction_mask = create_mask_from_length(instruction_lengths)
        content, content_mask, global_duration_pred, _ = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)
        batch_size = content.size(0)

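        # For classifier-free guidance, the unconditional (zeroed content) and
        # conditional branches run in a single batch of size 2 * batch_size.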
        if classifier_free_guidance:
            uncond_content = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            content = torch.cat([uncond_content, content])
            content_mask = torch.cat([uncond_content_mask, content_mask])

        scheduler.set_timesteps(num_steps, device=device)
        timesteps = scheduler.timesteps

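        # Invert the log-domain duration prediction: seconds = exp(pred) -
        # offset, then multiply by the latent token rate to get a latent
        # length in tokens.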
        global_duration_pred = torch.exp(
            global_duration_pred
        ) - self.duration_offset
        global_duration_pred *= self.autoencoder.latent_token_rate
        global_duration_pred = torch.round(global_duration_pred)

        latent_shape = tuple(
            int(global_duration_pred.max().item()) if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, content.dtype, device
        )
        latent_mask = create_mask_from_length(global_duration_pred).to(
            content_mask.device
        )
        if classifier_free_guidance:
            latent_mask = torch.cat([latent_mask, latent_mask])

        num_warmup_steps = len(timesteps) - num_steps * scheduler.order
        progress_bar = tqdm(range(num_steps), disable=disable_progress)

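        # Standard diffusers denoising loop: duplicate the latent for CFG,
        # predict noise, mix the unconditional and conditional branches, then
        # take one scheduler step.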
        for i, timestep in enumerate(timesteps):
            latent_input = torch.cat(
                [latent, latent]
            ) if classifier_free_guidance else latent
            latent_input = scheduler.scale_model_input(latent_input, timestep)

            noise_pred = self.backbone(
                x=latent_input,
                x_mask=latent_mask,
                timesteps=timestep,
                context=content,
                context_mask=content_mask,
            )

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_content - noise_pred_uncond
                )
                if guidance_rescale != 0.0:
                    noise_pred = self.rescale_cfg(
                        noise_pred_content, noise_pred, guidance_rescale
                    )

            latent = scheduler.step(noise_pred, timestep, latent).prev_sample

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update(1)

        progress_bar.close()

        waveform = self.autoencoder.decode(latent)

        return waveform

    def prepare_latent(
        self, batch_size: int, scheduler: SchedulerMixin,
        latent_shape: Sequence[int], dtype: torch.dtype,
        device: torch.device
    ) -> torch.Tensor:
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=dtype
        )
        latent = latent * scheduler.init_noise_sigma
        return latent


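# A minimal usage sketch (illustrative assumptions only: `model` is a trained
# CrossAttentionAudioDiffusion and the batch variables follow the signatures
# above; none of these names are defined in this module):
#
#     from diffusers import DPMSolverMultistepScheduler
#
#     scheduler = DPMSolverMultistepScheduler.from_pretrained(
#         "stabilityai/stable-diffusion-2-1", subfolder="scheduler"
#     )
#     waveform = model.inference(
#         content, condition, task, instruction, instruction_lengths,
#         scheduler=scheduler, num_steps=50, guidance_scale=3.0,
#         guidance_rescale=0.7,
#     )

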
class SingleTaskCrossAttentionAudioDiffusion(CrossAttentionAudioDiffusion):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        pretrained_ckpt: str | Path | None = None,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        nn.Module.__init__(self)
        DiffusionMixin.__init__(
            self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
        )

        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        self.backbone = backbone
        if pretrained_ckpt is not None:
            pretrained_state_dict = load_file(pretrained_ckpt)
            self.load_pretrained(pretrained_state_dict)

        self.content_encoder = content_encoder

        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self, content: list[Any], condition: list[Any], task: list[str],
        waveform: torch.Tensor, waveform_lengths: torch.Tensor,
        loss_reduce: bool = True, **kwargs
    ):
        # Always reduce during training; otherwise respect the caller's flag.
        loss_reduce = self.training or loss_reduce
        device = self.dummy_param.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        content, content_mask = content_output["content"], content_output[
            "content_mask"]

        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        batch_size = latent.shape[0]
        timesteps = self.get_timesteps(batch_size, device, self.training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )

        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = self.loss_with_snr(
            pred, target, timesteps, latent_mask, loss_reduce=loss_reduce
        )

        return {
            "diff_loss": diff_loss,
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        scheduler: SchedulerMixin,
        latent_shape: Sequence[int],
        num_steps: int = 20,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        content, content_mask = content_output["content"], content_output[
            "content_mask"]
        batch_size = content.size(0)

        if classifier_free_guidance:
            uncond_content = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            content = torch.cat([uncond_content, content])
            content_mask = torch.cat([uncond_content_mask, content_mask])

        scheduler.set_timesteps(num_steps, device=device)
        timesteps = scheduler.timesteps

        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, content.dtype, device
        )

        num_warmup_steps = len(timesteps) - num_steps * scheduler.order
        progress_bar = tqdm(range(num_steps), disable=disable_progress)

        for i, timestep in enumerate(timesteps):
            latent_input = torch.cat(
                [latent, latent]
            ) if classifier_free_guidance else latent
            latent_input = scheduler.scale_model_input(latent_input, timestep)

            noise_pred = self.backbone(
                x=latent_input,
                timesteps=timestep,
                context=content,
                context_mask=content_mask,
            )

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_content - noise_pred_uncond
                )
                if guidance_rescale != 0.0:
                    noise_pred = self.rescale_cfg(
                        noise_pred_content, noise_pred, guidance_rescale
                    )

            latent = scheduler.step(noise_pred, timestep, latent).prev_sample

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update(1)

        progress_bar.close()

        waveform = self.autoencoder.decode(latent)

        return waveform


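# In DummyContentAudioDiffusion, trainable "dummy" embeddings stand in for the
# conditioning branch a sample does not use: time-aligned samples receive a
# dummy cross-attention context, while non-time-aligned samples receive a
# dummy time-aligned context.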
class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        """
        Args:
            autoencoder:
                Pretrained audio autoencoder that encodes raw waveforms into
                the latent space and decodes latents back to waveforms.
            content_encoder:
                Module that produces content embeddings (e.g., from text,
                MIDI, or other modalities) used to guide the diffusion.
            content_adapter:
                Adapter that fuses task instruction embeddings with content
                embeddings and predicts durations for time-aligned tasks.
            backbone:
                U-Net or Transformer backbone that performs the core
                denoising in latent space.
            content_dim:
                Dimension of the content embeddings produced by the
                `content_encoder` and `content_adapter`.
            frame_resolution:
                Time resolution, in seconds, of each content frame, used when
                computing the duration loss.
            duration_offset:
                Small positive offset (in frames) added to durations to keep
                the log-domain duration prediction numerically stable.
            noise_scheduler_name:
                Identifier of the pretrained noise scheduler to use.
            snr_gamma:
                Clipping value for the min-SNR diffusion loss weighting.
            cfg_drop_ratio:
                Probability of dropping the content conditioning during
                training to support classifier-free guidance.
        """
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
            backbone=backbone,
            duration_offset=duration_offset,
            noise_scheduler_name=noise_scheduler_name,
            snr_gamma=snr_gamma,
            cfg_drop_ratio=cfg_drop_ratio,
        )
        self.frame_resolution = frame_resolution
        self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
        self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))

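    # Training: besides the cross-attention path, time-aligned tasks supervise
    # a per-token (local) duration head and expand content embeddings to the
    # latent frame rate through a monotonic alignment path.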
    def forward(
        self, content, duration, task, is_time_aligned, waveform,
        waveform_lengths, instruction, instruction_lengths, **kwargs
    ):
        device = self.dummy_param.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        length_aligned_content = content_output["length_aligned_content"]
        content, content_mask = content_output["content"], content_output[
            "content_mask"]
        instruction_mask = create_mask_from_length(instruction_lengths)

        content, content_mask, global_duration_pred, local_duration_pred = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)

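        # Duration targets live in log space: per-token frame counts for the
        # local head, total duration in seconds for the global head.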
        n_frames = torch.round(duration / self.frame_resolution)
        local_duration_target = torch.log(n_frames + self.duration_offset)
        global_duration_target = torch.log(
            latent_mask.sum(1) / self.autoencoder.latent_token_rate +
            self.duration_offset
        )

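        # Truncate the time-aligned branch to the longest time-aligned sample
        # in the batch; the local duration loss only counts time-aligned rows.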
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        local_duration_pred = local_duration_pred[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        local_duration_target = local_duration_target.to(
            dtype=local_duration_pred.dtype
        )
        local_duration_loss = loss_with_mask(
            (local_duration_target - local_duration_pred)**2,
            ta_content_mask,
            reduce=False
        )
        local_duration_loss *= is_time_aligned
        if is_time_aligned.sum().item() == 0:
            local_duration_loss *= 0.0
            local_duration_loss = local_duration_loss.mean()
        else:
            local_duration_loss = local_duration_loss.sum(
            ) / is_time_aligned.sum()

        global_duration_loss = F.mse_loss(
            global_duration_target, global_duration_pred
        )

        batch_size = latent.shape[0]
        timesteps = self.get_timesteps(batch_size, device, self.training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)

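        # Expand time-aligned content to the latent frame rate: round
        # per-token durations to latent counts, build a monotonic alignment
        # path, and distribute each token's embedding over its frames.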
        if is_time_aligned.sum() == 0 and \
                duration.size(1) < content_mask.size(1):
            duration = F.pad(
                duration, (0, content_mask.size(1) - duration.size(1))
            )
        n_latents = torch.round(duration * self.autoencoder.latent_token_rate)

        helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
            content_mask.device
        )
        attn_mask = ta_content_mask.unsqueeze(
            -1
        ) * helper_latent_mask.unsqueeze(1)

        align_path = create_alignment_path(n_latents, attn_mask)
        time_aligned_content = content[:, :trunc_ta_length]
        time_aligned_content = torch.matmul(
            align_path.transpose(1, 2).to(content.dtype), time_aligned_content
        )

        latent_length = noisy_latent.size(self.autoencoder.time_dim)
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, latent_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, latent_length, 1
        )

        time_aligned_content = time_aligned_content + length_aligned_content
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )

        context = content
        context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)

        context_mask = content_mask.detach().clone()
        context_mask[is_time_aligned, 1:] = False

        if is_time_aligned.sum().item() < batch_size:
            trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
        else:
            trunc_nta_length = content.size(1)
        context = context[:, :trunc_nta_length]
        context_mask = context_mask[:, :trunc_nta_length]

        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                time_aligned_content[mask_indices] = 0

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            time_aligned_context=time_aligned_content,
            context=context,
            x_mask=latent_mask,
            context_mask=context_mask
        )
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
        return {
            "diff_loss": diff_loss,
            "local_duration_loss": local_duration_loss,
            "global_duration_loss": global_duration_loss
        }

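    # Inference mirrors the training-time conditioning: durations come from
    # the duration heads (or from ground truth when `use_gt_duration` is set),
    # and the latent length is derived from them before denoising.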
    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: list[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 20,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        length_aligned_content = content_output["length_aligned_content"]
        content, content_mask = content_output["content"], content_output[
            "content_mask"]
        instruction_mask = create_mask_from_length(instruction_lengths)
        content, content_mask, global_duration_pred, local_duration_pred = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)

        scheduler.set_timesteps(num_steps, device=device)
        timesteps = scheduler.timesteps
        batch_size = content.size(0)

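        # Invert the log-domain local duration head: frames ~= ceil(exp(pred))
        # - offset, then convert frames to latent tokens via frame_resolution
        # and the autoencoder's latent token rate.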
        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        local_duration_pred = torch.exp(local_duration_pred) * content_mask
        local_duration_pred = torch.ceil(
            local_duration_pred
        ) - self.duration_offset
        local_duration_pred = torch.round(
            local_duration_pred * self.frame_resolution *
            self.autoencoder.latent_token_rate
        )
        local_duration_pred = local_duration_pred[:, :trunc_ta_length]

        if use_gt_duration and "duration" in kwargs:
            local_duration_pred = torch.round(
                torch.as_tensor(kwargs["duration"]) *
                self.autoencoder.latent_token_rate
            ).to(device)

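        # Time-aligned samples take their total length from the summed local
        # durations; other samples use the global duration head instead.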
        global_duration = local_duration_pred.sum(1)
        global_duration_pred = torch.exp(
            global_duration_pred
        ) - self.duration_offset
        global_duration_pred *= self.autoencoder.latent_token_rate
        global_duration_pred = torch.round(global_duration_pred)
        global_duration[~is_time_aligned] = global_duration_pred[
            ~is_time_aligned]

        time_aligned_content = content[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        latent_mask = create_mask_from_length(global_duration).to(
            content_mask.device
        )
        attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)

        align_path = create_alignment_path(local_duration_pred, attn_mask)
        time_aligned_content = torch.matmul(
            align_path.transpose(1, 2).to(content.dtype), time_aligned_content
        )
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )

        length_aligned_content = trim_or_pad_length(
            length_aligned_content, time_aligned_content.size(1), 1
        )
        time_aligned_content = time_aligned_content + length_aligned_content

        context = content
        context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
        context_mask = content_mask
        context_mask[is_time_aligned, 1:] = False

        if is_time_aligned.sum().item() < batch_size:
            trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
        else:
            trunc_nta_length = content.size(1)
        context = context[:, :trunc_nta_length]
        context_mask = context_mask[:, :trunc_nta_length]

        if classifier_free_guidance:
            uncond_time_aligned_content = torch.zeros_like(
                time_aligned_content
            )
            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        latent_shape = tuple(
            int(global_duration.max().item()) if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )
        latent = latent * scheduler.init_noise_sigma

        num_warmup_steps = len(timesteps) - num_steps * scheduler.order
        progress_bar = tqdm(range(num_steps), disable=disable_progress)

        for i, timestep in enumerate(timesteps):
            if classifier_free_guidance:
                latent_input = torch.cat([latent, latent])
            else:
                latent_input = latent

            latent_input = scheduler.scale_model_input(latent_input, timestep)
            noise_pred = self.backbone(
                x=latent_input,
                x_mask=latent_mask,
                timesteps=timestep,
                time_aligned_context=time_aligned_content,
                context=context,
                context_mask=context_mask
            )

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                if guidance_rescale != 0.0:
                    noise_pred = self.rescale_cfg(
                        noise_pred_cond, noise_pred, guidance_rescale
                    )

            latent = scheduler.step(noise_pred, timestep, latent).prev_sample

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update(1)

        progress_bar.close()

        waveform = self.autoencoder.decode(latent)
        return waveform


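# DoubleContentAudioDiffusion always feeds both conditioning branches: every
# sample gets a time-aligned context and a cross-attention context, with no
# dummy placeholders for the unused branch.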
class DoubleContentAudioDiffusion(CrossAttentionAudioDiffusion):
    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: nn.Module,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
            backbone=backbone,
            duration_offset=duration_offset,
            noise_scheduler_name=noise_scheduler_name,
            snr_gamma=snr_gamma,
            cfg_drop_ratio=cfg_drop_ratio
        )
        self.frame_resolution = frame_resolution

    def forward(
        self, content, duration, task, is_time_aligned, waveform,
        waveform_lengths, instruction, instruction_lengths, **kwargs
    ):
        device = self.dummy_param.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        length_aligned_content = content_output["length_aligned_content"]
        content, content_mask = content_output["content"], content_output[
            "content_mask"]

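        # The cross-attention mask is captured before the content adapter
        # updates `content_mask`; it is reused as `context_mask` for the
        # backbone call below.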
        context_mask = content_mask.detach()
        instruction_mask = create_mask_from_length(instruction_lengths)

        content, content_mask, global_duration_pred, local_duration_pred = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)

        n_frames = torch.round(duration / self.frame_resolution)
        local_duration_target = torch.log(n_frames + self.duration_offset)
        global_duration_target = torch.log(
            latent_mask.sum(1) / self.autoencoder.latent_token_rate +
            self.duration_offset
        )

        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        local_duration_pred = local_duration_pred[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        local_duration_target = local_duration_target.to(
            dtype=local_duration_pred.dtype
        )
        local_duration_loss = loss_with_mask(
            (local_duration_target - local_duration_pred)**2,
            ta_content_mask,
            reduce=False
        )
        local_duration_loss *= is_time_aligned
        if is_time_aligned.sum().item() == 0:
            local_duration_loss *= 0.0
            local_duration_loss = local_duration_loss.mean()
        else:
            local_duration_loss = local_duration_loss.sum(
            ) / is_time_aligned.sum()

        global_duration_loss = F.mse_loss(
            global_duration_target, global_duration_pred
        )

        batch_size = latent.shape[0]
        timesteps = self.get_timesteps(batch_size, device, self.training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)

        if is_time_aligned.sum() == 0 and \
                duration.size(1) < content_mask.size(1):
            duration = F.pad(
                duration, (0, content_mask.size(1) - duration.size(1))
            )
        n_latents = torch.round(duration * self.autoencoder.latent_token_rate)
        helper_latent_mask = create_mask_from_length(n_latents.sum(1)).to(
            content_mask.device
        )
        attn_mask = ta_content_mask.unsqueeze(
            -1
        ) * helper_latent_mask.unsqueeze(1)
        align_path = create_alignment_path(n_latents, attn_mask)
        time_aligned_content = content[:, :trunc_ta_length]
        time_aligned_content = torch.matmul(
            align_path.transpose(1, 2).to(content.dtype), time_aligned_content
        )

        latent_length = noisy_latent.size(self.autoencoder.time_dim)
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, latent_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, latent_length, 1
        )
        time_aligned_content = time_aligned_content + length_aligned_content
        context = content

        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                time_aligned_content[mask_indices] = 0

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            time_aligned_context=time_aligned_content,
            context=context,
            x_mask=latent_mask,
            context_mask=context_mask,
        )
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
        return {
            "diff_loss": diff_loss,
            "local_duration_loss": local_duration_loss,
            "global_duration_loss": global_duration_loss,
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: list[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 20,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        content_output: dict[str, torch.Tensor] = \
            self.content_encoder.encode_content(content, task, device=device)
        length_aligned_content = content_output["length_aligned_content"]
        content, content_mask = content_output["content"], content_output[
            "content_mask"]
        instruction_mask = create_mask_from_length(instruction_lengths)

        content, content_mask, global_duration_pred, local_duration_pred = \
            self.content_adapter(content, content_mask, instruction, instruction_mask)

        scheduler.set_timesteps(num_steps, device=device)
        timesteps = scheduler.timesteps
        batch_size = content.size(0)

        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        local_duration_pred = torch.exp(local_duration_pred) * content_mask
        local_duration_pred = torch.ceil(
            local_duration_pred
        ) - self.duration_offset
        local_duration_pred = torch.round(
            local_duration_pred * self.frame_resolution *
            self.autoencoder.latent_token_rate
        )
        local_duration_pred = local_duration_pred[:, :trunc_ta_length]

        if use_gt_duration and "duration" in kwargs:
            local_duration_pred = torch.round(
                torch.as_tensor(kwargs["duration"]) *
                self.autoencoder.latent_token_rate
            ).to(device)

        global_duration = local_duration_pred.sum(1)
        global_duration_pred = torch.exp(
            global_duration_pred
        ) - self.duration_offset
        global_duration_pred *= self.autoencoder.latent_token_rate
        global_duration_pred = torch.round(global_duration_pred)
        global_duration[~is_time_aligned] = global_duration_pred[
            ~is_time_aligned]

        time_aligned_content = content[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        latent_mask = create_mask_from_length(global_duration).to(
            content_mask.device
        )
        attn_mask = ta_content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)

        align_path = create_alignment_path(local_duration_pred, attn_mask)
        time_aligned_content = torch.matmul(
            align_path.transpose(1, 2).to(content.dtype), time_aligned_content
        )

        length_aligned_content = trim_or_pad_length(
            length_aligned_content, time_aligned_content.size(1), 1
        )
        time_aligned_content = time_aligned_content + length_aligned_content

        context = content
        context_mask = content_mask

        if classifier_free_guidance:
            uncond_time_aligned_content = torch.zeros_like(
                time_aligned_content
            )
            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        latent_shape = tuple(
            int(global_duration.max().item()) if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=content.dtype
        )
        latent = latent * scheduler.init_noise_sigma

        num_warmup_steps = len(timesteps) - num_steps * scheduler.order
        progress_bar = tqdm(range(num_steps), disable=disable_progress)

        for i, timestep in enumerate(timesteps):
            if classifier_free_guidance:
                latent_input = torch.cat([latent, latent])
            else:
                latent_input = latent

            latent_input = scheduler.scale_model_input(latent_input, timestep)
            noise_pred = self.backbone(
                x=latent_input,
                x_mask=latent_mask,
                timesteps=timestep,
                time_aligned_context=time_aligned_content,
                context=context,
                context_mask=context_mask
            )

            if classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                if guidance_rescale != 0.0:
                    noise_pred = self.rescale_cfg(
                        noise_pred_cond, noise_pred, guidance_rescale
                    )

            latent = scheduler.step(noise_pred, timestep, latent).prev_sample

            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0
            ):
                progress_bar.update(1)

        progress_bar.close()

        waveform = self.autoencoder.decode(latent)
        return waveform