Unify duration scale: use latent token count for both training and inference
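This commit moves all duration bookkeeping onto a single scale: latent token counts. Predictions made in seconds are converted exactly once via `latent_token_rate`, and sequence lengths are obtained by summing per-token counts rather than re-rounding seconds at each call site. A minimal sketch of the convention (`latent_token_rate`, the number of latent tokens per second, is taken from the code; the helper itself is illustrative):

import torch

def to_latent_tokens(duration_sec: torch.Tensor, latent_token_rate: float) -> torch.Tensor:
    # Seconds -> latent tokens: round once, at the boundary.
    return torch.round(duration_sec * latent_token_rate)

# Local (per-token) durations are already token counts, so a sequence
# length is just their integer sum, which accumulates no rounding error:
#     latent_length = local_latent_duration.sum(1)

Previously the local path converted back to seconds (`local_pred.sum(1) / self.latent_token_rate`) while call sites rounded seconds to tokens on their own, so training (sum of local durations) and inference (predicted global duration) could disagree.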
models/flow_matching.py  CHANGED  (+45 -25)
@@ -9,6 +9,7 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import init
 
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers import FlowMatchEulerDiscreteScheduler
@@ -475,6 +476,7 @@ class DurationAdapterMixin:
         pred = torch.exp(pred) * mask
         pred = torch.ceil(pred) - self.offset
         pred *= self.frame_resolution
+        pred = torch.round(pred * self.latent_token_rate)
         return pred
 
     def prepare_global_duration(
@@ -489,11 +491,10 @@ class DurationAdapterMixin:
             local_pred: predicted latent length
         """
         global_pred = torch.exp(global_pred) - self.offset
-        result = global_pred
+        result = torch.round(global_pred * self.latent_token_rate)
         # avoid error accumulation for each frame
         if use_local:
-            pred_from_local = local_pred
-            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
+            pred_from_local = local_pred.sum(1)
             result[is_time_aligned] = pred_from_local[is_time_aligned]
 
         return result
@@ -503,20 +504,18 @@ class DurationAdapterMixin:
         x: torch.Tensor,
         content_mask: torch.Tensor,
         local_duration: torch.Tensor,
-        global_duration: torch.Tensor
+        global_duration: torch.Tensor,
     ):
-
-        if
-            latent_length =
-
-        )
-        else:
-            latent_length = n_latents.sum(1)
+        training = getattr(self, 'training', False)
+        if not training:  # inference mode
+            latent_length = global_duration
+        else:  # training mode
+            latent_length = local_duration.sum(1)
         latent_mask = create_mask_from_length(latent_length).to(
             content_mask.device
         )
         attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
-        align_path = create_alignment_path(
+        align_path = create_alignment_path(local_duration, attn_mask)
         expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
         return expanded_x, latent_mask
 
@@ -665,14 +664,12 @@ class CrossAttentionAudioFlowMatching(
         )
 
         # prepare global duration
-        global_duration = self.prepare_global_duration(
+        latent_length = self.prepare_global_duration(
             global_duration_pred,
             local_duration_pred,
             is_time_aligned,
             use_local=False
         )
-        # TODO: manually set duration for SE and AudioSR
-        latent_length = torch.round(global_duration * self.latent_token_rate)
         task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
         latent_length[task_mask] = content[task_mask].size(1)
         latent_mask = create_mask_from_length(latent_length).to(device)
@@ -735,7 +732,8 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         duration_offset: float = 1.0,
         cfg_drop_ratio: float = 0.2,
         sample_strategy: str = 'normal',
-        num_train_steps: int = 1000
+        num_train_steps: int = 1000,
+        task_weights: dict | None = None,
     ):
 
         super().__init__(
@@ -758,6 +756,7 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         )
         self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
         self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
+        self.task_weights = task_weights
 
     def get_backbone_input(
         self, target_length: int, content: torch.Tensor,
@@ -808,7 +807,12 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         **kwargs
     ):
         device = self.dummy_param.device
-        loss_reduce = self.training or (loss_reduce and not self.training)
+        if self.training:
+            if self.task_weights:
+                loss_reduce = False
+            else:
+                loss_reduce = True
+        # loss_reduce = self.training or (loss_reduce and not self.training)
 
         self.autoencoder.eval()
         with torch.no_grad():
@@ -859,10 +863,12 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
             duration = F.pad(
                 duration, (0, content_mask.size(1) - duration.size(1))
             )
+            local_latent_duration = torch.round(duration * self.latent_token_rate)
             time_aligned_content, _ = self.expand_by_duration(
                 x=content[:, :trunc_ta_length],
                 content_mask=ta_content_mask,
-                local_duration=duration,
+                local_duration=local_latent_duration,
+                global_duration=latent_mask.sum(1),
             )
 
         # --------------------------------------------------------------------
@@ -899,6 +905,16 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         target = target.transpose(1, self.autoencoder.time_dim)
         diff_loss = F.mse_loss(pred, target, reduction="none")
         diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
+
+        if self.training and self.task_weights:
+            loss_weights = torch.tensor([self.task_weights[t] for t in task],
+                                        device=device)
+            diff_loss = (diff_loss * loss_weights).sum() / loss_weights.sum()
+            local_duration_loss = (local_duration_loss *
+                                   loss_weights).sum() / loss_weights.sum()
+            global_duration_loss = (global_duration_loss *
+                                    loss_weights).sum() / loss_weights.sum()
+
         return {
             "diff_loss": diff_loss,
             "local_duration_loss": local_duration_loss,
@@ -939,17 +955,21 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         trunc_ta_length = content.size(1)
 
         # prepare local duration
-        local_duration = self.prepare_local_duration(
+        local_latent_duration = self.prepare_local_duration(
            local_duration_pred, content_mask
        )
-        local_duration = local_duration[:, :trunc_ta_length]
+        local_latent_duration = local_latent_duration[:, :trunc_ta_length]
         # use ground truth duration
         if use_gt_duration and "duration" in kwargs:
-            local_duration = torch.as_tensor(kwargs["duration"]).to(device)
+            local_latent_duration = torch.as_tensor(kwargs["duration"]
+                                                    ).to(device)
+            local_latent_duration = torch.round(
+                local_latent_duration * self.latent_token_rate
+            )
 
         # prepare global duration
-        global_duration = self.prepare_global_duration(
-            global_duration_pred,
+        global_latent_duration = self.prepare_global_duration(
+            global_duration_pred, local_latent_duration, is_time_aligned
         )
 
         # --------------------------------------------------------------------
@@ -958,8 +978,8 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         time_aligned_content, latent_mask = self.expand_by_duration(
             x=content[:, :trunc_ta_length],
             content_mask=content_mask[:, :trunc_ta_length],
-            local_duration=local_duration,
-            global_duration=global_duration,
+            local_duration=local_latent_duration,
+            global_duration=global_latent_duration,
         )
 
         context, context_mask, time_aligned_content = self.get_backbone_input(
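For context on `expand_by_duration`: it builds a hard alignment path from the per-token latent durations and expands content with a matrix product (`align_path.transpose(1, 2) @ x`). The diff only shows the call `create_alignment_path(local_duration, attn_mask)`; the sketch below is a hypothetical implementation of the usual cumulative-sum construction of such a monotonic path, not the repo's actual code:

import torch
import torch.nn.functional as F

def create_alignment_path(duration: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    # duration:  (B, T_content) integer latent-token counts per content token
    # attn_mask: (B, T_content, T_latent) validity mask
    cum = torch.cumsum(duration, dim=1)                      # (B, T_content)
    grid = torch.arange(attn_mask.size(2), device=duration.device)
    # path[b, i, j] = 1 while j < cum[b, i] ...
    path = (grid[None, None, :] < cum[:, :, None]).to(attn_mask.dtype)
    # ... then keep only the band for token i by subtracting token i-1's rows
    path = path - F.pad(path, (0, 0, 1, 0))[:, :-1]
    return path * attn_mask

With this construction, `matmul(path.transpose(1, 2), x)` repeats each content vector exactly `duration[i]` times, which is why the latent length and `local_duration.sum(1)` must agree on the same token scale; that is the invariant this commit enforces.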
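The second change in this commit is optional per-task loss weighting. When `task_weights` is set, `loss_with_mask(..., reduce=False)` is expected to return one loss per batch item, and each loss term is then reduced as a weighted mean over the batch. A standalone sketch of that reduction (the weighted-mean arithmetic is from the diff; the helper name and example weights are hypothetical):

import torch

def weighted_batch_mean(per_sample_loss: torch.Tensor,
                        tasks: list[str],
                        task_weights: dict[str, float]) -> torch.Tensor:
    # per_sample_loss: (B,) losses, masking already applied per item
    w = torch.tensor([task_weights[t] for t in tasks],
                     device=per_sample_loss.device)
    return (per_sample_loss * w).sum() / w.sum()

# Hypothetical usage, e.g. down-weighting an over-represented task:
#     loss = weighted_batch_mean(diff_loss, task, {"tts": 1.0, "se": 0.3})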