import torch
from torch import nn
from torch.nn import functional as F
from functools import partial
import numpy as np
import typing as tp

from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .dit import DiffusionTransformer
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform
from ..inference.generation import generate_diffusion_cond

from .adp import UNetCFG1d, UNet1d

from time import time

class Profiler:
    """Wall-clock profiler that records named ticks and reports the elapsed time between consecutive ticks."""

    def __init__(self):
        self.ticks = [[time(), None]]

    def tick(self, msg):
        self.ticks.append([time(), msg])

    def __repr__(self):
        rep = 80 * "=" + "\n"
        for i in range(1, len(self.ticks)):
            msg = self.ticks[i][1]
            elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
            rep += msg + f": {elapsed * 1000:.2f}ms\n"
        rep += 80 * "=" + "\n\n\n"
        return rep

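# Example usage of Profiler (a minimal sketch; the stage functions are hypothetical):
#
#   p = Profiler()
#   run_first_stage()   # hypothetical work
#   p.tick("first stage")
#   run_second_stage()  # hypothetical work
#   p.tick("second stage")
#   print(p)            # prints the per-stage elapsed times in ms
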
class DiffusionModel(nn.Module):
    """Base class for unconditional diffusion models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        raise NotImplementedError()

class DiffusionModelWrapper(nn.Module):
    """Wraps an unconditional diffusion model together with its audio I/O metadata and an optional pretransform (e.g. a latent autoencoder)."""

    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model
        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

class ConditionedDiffusionModel(nn.Module):
    """Base class for conditioned diffusion models. The supports_* flags advertise which conditioning modes the underlying architecture can consume."""

    def __init__(self,
                 *args,
                 supports_cross_attention: bool = False,
                 supports_input_concat: bool = False,
                 supports_global_cond: bool = False,
                 supports_prepend_cond: bool = False,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.supports_cross_attention = supports_cross_attention
        self.supports_input_concat = supports_input_concat
        self.supports_global_cond = supports_global_cond
        self.supports_prepend_cond = supports_prepend_cond

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                cross_attn_cond: tp.Optional[torch.Tensor] = None,
                cross_attn_mask: tp.Optional[torch.Tensor] = None,
                input_concat_cond: tp.Optional[torch.Tensor] = None,
                global_embed: tp.Optional[torch.Tensor] = None,
                prepend_cond: tp.Optional[torch.Tensor] = None,
                prepend_cond_mask: tp.Optional[torch.Tensor] = None,
                cfg_scale: float = 1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                **kwargs):
        raise NotImplementedError()

class ConditionedDiffusionModelWrapper(nn.Module):
    """
    Wraps a conditioned diffusion model together with its conditioner, routing each named
    conditioner output to the conditioning input (cross-attention, global, input-concat,
    or prepend) named by the corresponding *_ids list.
    """

    def __init__(
        self,
        model: ConditionedDiffusionModel,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.List[str] = [],
        global_cond_ids: tp.List[str] = [],
        input_concat_ids: tp.List[str] = [],
        prepend_cond_ids: tp.List[str] = [],
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.global_cond_ids = global_cond_ids
        self.input_concat_ids = input_concat_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.min_input_length = min_input_length

    def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension.
            # Assumes the cross-attention inputs are of shape (batch, seq, channels).
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]

                # Add a sequence dimension if it is missing
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension.
            # Assumes the global conditioning inputs are of shape (batch, channels).
            global_conds = []
            for key in self.global_cond_ids:
                global_cond_input = conditioning_tensors[key][0]
                global_conds.append(global_cond_input)

            global_cond = torch.cat(global_conds, dim=-1)

            # Remove a singleton sequence dimension if present
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input-concat conditioning inputs over the channel dimension.
            # Assumes the input-concat conditioning inputs are of shape (batch, channels, seq).
            input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension.
            # Assumes the prepend conditioning inputs are of shape (batch, seq, channels).
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask
            }

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        return generate_diffusion_cond(self, *args, **kwargs)

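# Example usage of ConditionedDiffusionModelWrapper.generate (a minimal sketch;
# the conditioning keys, values, and sizes below are assumptions that depend on
# the model's conditioning config):
#
#   model = create_diffusion_cond_from_config(config)
#   audio = model.generate(
#       steps=100,
#       cfg_scale=7.0,
#       conditioning=[{"prompt": "warm analog synth pad"}],
#       sample_size=model.sample_rate * 10,
#       device="cuda",
#   )
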
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)

        self.model = UNetCFG1d(*args, **kwargs)

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                input_concat_cond=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                **kwargs):
        p = Profiler()

        p.tick("start")

        channels_list = None
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs)

        p.tick("UNetCFG1D forward")

        return outputs

class UNet1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)

        self.model = UNet1d(*args, **kwargs)

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                global_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                **kwargs):

        channels_list = None
        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            features=global_cond,
            channels_list=channels_list,
            **kwargs)

        return outputs

class UNet1DUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = UNet1d(in_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

class DAU1DCondWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                input_concat_cond=None,
                cross_attn_cond=None,
                cross_attn_mask=None,
                global_cond=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = False,
                rescale_cfg: bool = False,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                negative_global_cond=None,
                negative_input_concat_cond=None,
                prepend_cond=None,
                **kwargs):

        # DiffusionAttnUnet1D only supports input-concat conditioning; the other
        # conditioning arguments are accepted for interface compatibility and ignored.
        return self.model(x, t, cond=input_concat_cond)

class DiffusionAttnUnet1D(nn.Module):
    def __init__(
        self,
        io_channels=2,
        depth=14,
        n_attn_layers=6,
        channels=[128, 128, 256, 256] + [512] * 10,
        cond_dim=0,
        cond_noise_aug=False,
        kernel_size=5,
        learned_resample=False,
        strides=[2] * 13,
        conv_bias=True,
        use_snake=False
    ):
        super().__init__()

        self.cond_noise_aug = cond_noise_aug

        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        attn_layer = depth - n_attn_layers

        strides = [1] + strides

        block = nn.Identity()

        conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias=conv_bias, use_snake=use_snake)

        # Build the U-Net from the innermost level outward, wrapping the previous
        # block at each step; self-attention is only added in the deepest n_attn_layers levels.
        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Strides greater than 2 require learned resampling")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    block,
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev // 32) if add_attn else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
                )
            else:
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
        self.net = block

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):

        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)

        inputs = [x, timestep_embed]

        if cond is not None:
            # Interpolate the conditioning signal to the same length as x
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)

            if self.cond_noise_aug:
                # Draw a per-item noise augmentation level, or use the provided scale
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Embed the noise augmentation level, reusing the timestep embedding
                aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)

                inputs.append(aug_level_embed)

            inputs.append(cond)

        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs

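# Example forward pass for DiffusionAttnUnet1D (a minimal sketch; the
# hyperparameters and shapes below are illustrative assumptions):
#
#   model = DiffusionAttnUnet1D(io_channels=2, depth=4, n_attn_layers=0,
#                               channels=[64, 64, 128, 128], strides=[2] * 3)
#   x = torch.randn(4, 2, 2 ** 15)  # (batch, io_channels, samples)
#   t = torch.rand(4)               # one diffusion timestep per batch item
#   v = model(x, t)                 # output has the same shape as x
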
class DiTWrapper(ConditionedDiffusionModel):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)

        self.model = DiffusionTransformer(*args, **kwargs)

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self,
                x,
                t,
                cross_attn_cond=None,
                cross_attn_mask=None,
                negative_cross_attn_cond=None,
                negative_cross_attn_mask=None,
                input_concat_cond=None,
                negative_input_concat_cond=None,
                global_cond=None,
                negative_global_cond=None,
                prepend_cond=None,
                prepend_cond_mask=None,
                cfg_scale=1.0,
                cfg_dropout_prob: float = 0.0,
                batch_cfg: bool = True,
                rescale_cfg: bool = False,
                scale_phi: float = 0.0,
                **kwargs):

        assert batch_cfg, "batch_cfg must be True for DiTWrapper"

        return self.model(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
            **kwargs)

class DiTUncondWrapper(DiffusionModel):
    def __init__(
        self,
        in_channels,
        *args,
        **kwargs
    ):
        super().__init__()

        self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        # Halve the initial weights of the model
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)

def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    diffusion_uncond_config = config["model"]

    model_type = diffusion_uncond_config.get('type', None)
    assert model_type is not None, "Must specify model type in config"

    diffusion_config = diffusion_uncond_config.get('config', {})

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == 'DAU1d':
        model = DiffusionAttnUnet1D(
            **diffusion_config
        )
    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(
            **diffusion_config
        )
    elif model_type == "dit":
        model = DiTUncondWrapper(
            **diffusion_config
        )
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return DiffusionModelWrapper(model,
                                 io_channels=model.io_channels,
                                 sample_size=sample_size,
                                 sample_rate=sample_rate,
                                 pretransform=pretransform,
                                 min_input_length=min_input_length)

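# Example config for create_diffusion_uncond_from_config (a minimal sketch;
# the inner "config" keys depend on the chosen model type and are elided):
#
#   config = {
#       "sample_size": 65536,
#       "sample_rate": 48000,
#       "model": {
#           "type": "dit",
#           "config": {...},  # DiffusionTransformer kwargs, e.g. in_channels
#       },
#   }
#   model = create_diffusion_uncond_from_config(config)
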
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):

    model_config = config["model"]

    model_type = config["model_type"]

    diffusion_config = model_config.get('diffusion', None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get('type', None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get('config', None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    # video_fps is not a model constructor argument, so drop it if present
    if diffusion_model_config.get('video_fps', None) is not None:
        diffusion_model_config.pop('video_fps')

    if diffusion_model_type == 'adp_cfg_1d':
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'adp_1d':
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == 'dit':
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')

    io_channels = model_config.get('io_channels', None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get('diffusion_objective', 'v')

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
    global_cond_ids = diffusion_config.get('global_cond_ids', [])
    input_concat_ids = diffusion_config.get('input_concat_ids', [])
    prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    extra_kwargs = {}

    if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
        wrapper_fn = ConditionedDiffusionModelWrapper
        extra_kwargs["diffusion_objective"] = diffusion_objective
    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, "Must specify prior_type in diffusion prior model config"

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior
            wrapper_fn = MonoToStereoDiffusionPrior
        else:
            raise NotImplementedError(f'Unknown prior type: {prior_type}')
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )

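# Example config for create_diffusion_cond_from_config (a minimal sketch; the
# nested "config" dicts and the conditioning setup are elided, and the
# conditioner id "prompt" is an assumption tied to the conditioning config):
#
#   config = {
#       "model_type": "diffusion_cond",
#       "sample_rate": 44100,
#       "model": {
#           "io_channels": 64,
#           "pretransform": {...},
#           "conditioning": {...},
#           "diffusion": {
#               "type": "dit",
#               "diffusion_objective": "v",
#               "cross_attention_cond_ids": ["prompt"],
#               "config": {...},  # DiffusionTransformer kwargs
#           },
#       },
#   }
#   model = create_diffusion_cond_from_config(config)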