import time
from functools import partial

import torch
from torch.nn import Module

# torchode is referenced below via the `to` alias
import torchode as to
from torchdiffeq import odeint

from beartype import beartype
from beartype.typing import Optional

from einops import rearrange, repeat

from modules.audio2motion.cfm.utils import *
from modules.audio2motion.cfm.icl_transformer import InContextTransformerAudio2Motion
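
# wraps an in-context audio-to-motion transformer with conditional flow matching:
# `forward` is the flow matching training step and `sample` integrates the learned
# flow with an ODE solver (torchdiffeq or torchode), following eq. (5)/(6) of
# https://arxiv.org/pdf/2306.15687.pdf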


def is_probably_audio_from_shape(t):
    # raw waveform batches are (b, t); mono channel-first audio is (b, 1, t)
    return exists(t) and (t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1))

class ConditionalFlowMatcherWrapper(Module):
    @beartype
    def __init__(
        self,
        icl_transformer_model: Optional[InContextTransformerAudio2Motion] = None,
        sigma = 0.,
        ode_atol = 1e-5,
        ode_rtol = 1e-5,
        use_torchode = False,
        torchdiffeq_ode_method = 'midpoint',
        torchode_method_klass = to.Tsit5,
        cond_drop_prob = 0.
    ):
        super().__init__()
        self.sigma = sigma

        if icl_transformer_model is None:
            icl_transformer_model = InContextTransformerAudio2Motion()

        self.icl_transformer_model = icl_transformer_model
        # sample() needs the conditioning feature dim; assumed to be exposed by the transformer
        self.dim_cond_emb = icl_transformer_model.dim_cond_emb

        self.cond_drop_prob = cond_drop_prob

        self.use_torchode = use_torchode
        self.torchode_method_klass = torchode_method_klass

        self.odeint_kwargs = dict(
            atol = ode_atol,
            rtol = ode_rtol,
            method = torchdiffeq_ode_method
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.inference_mode()
    def sample(
        self,
        *,
        cond_audio = None,
        cond = None,
        cond_mask = None,
        steps = 3,
        cond_scale = 1.,
        ret = None,
        self_attn_mask = None,
        temperature = 1.0,
    ):
        if ret is None:
            ret = {}
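
        # derive the motion length from the audio: the audio features are assumed
        # to run at twice the motion frame rate (hence the // 2)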
        assert exists(cond_audio), 'cond_audio must be provided for sampling'
        cond_target_length = cond_audio.shape[1] // 2

        if exists(cond):
            cond = curtail_or_pad(cond, cond_target_length)
        else:
            cond = torch.zeros((cond_audio.shape[0], cond_target_length, self.dim_cond_emb), device = self.device)

        shape = cond.shape
        batch = shape[0]

        self.icl_transformer_model.eval()
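
        # drift function for the ODE solver: the transformer predicts the flow
        # (velocity) at time t, with classifier-free guidance at `cond_scale`;
        # `packed_shape` is only set on the torchode path, which flattens the state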
        def fn(t, x, *, packed_shape = None):
            if exists(packed_shape):
                x = unpack_one(x, packed_shape, 'b *')

            out = self.icl_transformer_model.forward_with_cond_scale(
                x,
                times = t,
                cond_audio = cond_audio,
                cond = cond,
                cond_scale = cond_scale,
                cond_mask = cond_mask,
                self_attn_mask = self_attn_mask,
                ret = ret,
            )

            if exists(packed_shape):
                out = rearrange(out, 'b ... -> b (...)')

            return out
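
        # integrate from pure noise at t = 0 to data at t = 1;
        # `temperature` scales the initial noise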
        y0 = torch.randn_like(cond) * float(temperature)
        t = torch.linspace(0, 1, steps, device = self.device)

        timestamp_before_sampling = time.time()

        if not self.use_torchode:
            print(f'sampling with torchdiffeq, flow total_steps={steps}')

            trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
            sampled = trajectory[-1]
        else:
            print(f'sampling with torchode, flow total_steps={steps}')
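
            # torchode solves a batch of independent IVPs over flat state vectors,
            # so pack (b, n, d) -> (b, n * d) and give each sample its own time grid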
            t = repeat(t, 'n -> b n', b = batch)
            y0, packed_shape = pack_one(y0, 'b *')

            fn = partial(fn, packed_shape = packed_shape)
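
            # assemble the torchode solver: the term wraps the drift fn, the step
            # method advances it, and the controller adapts step size to atol/rtol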
            term = to.ODETerm(fn)
            step_method = self.torchode_method_klass(term = term)

            step_size_controller = to.IntegralController(
                atol = self.odeint_kwargs['atol'],
                rtol = self.odeint_kwargs['rtol'],
                term = term
            )

            solver = to.AutoDiffAdjoint(step_method, step_size_controller)
            jit_solver = torch.compile(solver)

            init_value = to.InitialValueProblem(y0 = y0, t_eval = t)
            sol = jit_solver.solve(init_value)

            sampled = sol.ys[:, -1]
            sampled = unpack_one(sampled, packed_shape, 'b *')

        print(f'flow matching sampling elapsed in {time.time() - timestamp_before_sampling:.4f} seconds')
        return sampled

    def forward(
        self,
        x1,
        *,
        mask = None,
        cond_audio = None,
        cond = None,
        cond_mask = None,
        ret = None,
    ):
        """
        training step of a continuous normalizing flow via conditional flow matching,
        following eq. (5) and (6) in https://arxiv.org/pdf/2306.15687.pdf
        """
        if ret is None:
            ret = {}
        batch, seq_len, dtype, sigma_ = *x1.shape[:2], x1.dtype, self.sigma

        # x0 is the noise sample the flow transports to the data x1
        x0 = torch.randn_like(x1)

        # one random flow time per batch element, broadcast over sequence and feature dims
        times = torch.rand((batch,), dtype = dtype, device = self.device)
        t = rearrange(times, 'b -> b 1 1')
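
        # optimal transport conditional path, eq. (5)/(6) of the paper:
        #   phi_t(x) = (1 - (1 - sigma) * t) * x0 + t * x1
        #   u_t      = x1 - (1 - sigma) * x0
        # the model regresses the constant target flow u_t at the sampled time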
        current_position_in_flows = (1 - (1 - sigma_) * t) * x0 + t * x1
        optimal_path = x1 - (1 - sigma_) * x0

        self.icl_transformer_model.train()

        loss = self.icl_transformer_model(
            current_position_in_flows,
            cond = cond,
            cond_mask = cond_mask,
            times = times,
            target = optimal_path,
            self_attn_mask = mask,
            cond_audio = cond_audio,
            cond_drop_prob = self.cond_drop_prob,
            ret = ret,
        )
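
        # the network regresses u_t = x1 - (1 - sigma) * x0, so x1 itself can be
        # recovered from the prediction for logging / downstream consumers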
        pred_x1_minus_x0 = ret['pred']
        pred_x1 = pred_x1_minus_x0 + (1 - sigma_) * x0
        ret['pred'] = pred_x1

        return loss


if __name__ == '__main__':
    icl_transformer = InContextTransformerAudio2Motion()
    model = ConditionalFlowMatcherWrapper(icl_transformer)

    x = torch.randn([2, 125, 64])
    cond = torch.randn([2, 125, 64])
    cond_audio = torch.randn([2, 250, 1024])

    # forward returns the flow matching loss; sample integrates the learned flow
    loss = model(x, cond=cond, cond_audio=cond_audio)
    y = model.sample(cond=cond, cond_audio=cond_audio)
    print(y.shape)