wsntxxn committed
Commit 51bd0e0 · Parent(s): 2d1110f

Unify duration scale: use latent token count for both training and inference

Files changed (1): models/flow_matching.py (+45 −25)
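The change in a nutshell: durations used to circulate in seconds and were rounded to latent tokens separately at each call site; after this commit the prepare_* helpers return latent token counts directly. A minimal sketch of the conversion, with an illustrative latent_token_rate (the real value is a model attribute):

import torch

latent_token_rate = 25.0                      # assumed: latent tokens per second
seconds = torch.tensor([1.00, 2.37, 0.48])
latent_tokens = torch.round(seconds * latent_token_rate)
print(latent_tokens)                          # tensor([25., 59., 12.])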
models/flow_matching.py CHANGED
@@ -9,6 +9,7 @@ import copy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn import init

 from diffusers.utils.torch_utils import randn_tensor
 from diffusers import FlowMatchEulerDiscreteScheduler
@@ -475,6 +476,7 @@ class DurationAdapterMixin:
         pred = torch.exp(pred) * mask
         pred = torch.ceil(pred) - self.offset
         pred *= self.frame_resolution
+        pred = torch.round(pred * self.latent_token_rate)
         return pred

     def prepare_global_duration(
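For reference, the full prepare_local_duration pipeline after this hunk, as a standalone sketch with made-up offset / frame_resolution / latent_token_rate values (all model attributes in the real code):

import torch

offset, frame_resolution, latent_token_rate = 1.0, 0.01, 25.0
pred = torch.tensor([[4.5, 3.2]])             # log-scale duration predictions
mask = torch.tensor([[1.0, 1.0]])
pred = torch.exp(pred) * mask                 # back to linear scale, masked
pred = torch.ceil(pred) - offset              # integer frame count
pred = pred * frame_resolution                # frames -> seconds
pred = torch.round(pred * latent_token_rate)  # seconds -> latent tokens (the new line)
print(pred)                                   # tensor([[22., 6.]])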
@@ -489,11 +491,10 @@ class DurationAdapterMixin:
             local_pred: predicted latent length
         """
         global_pred = torch.exp(global_pred) - self.offset
-        result = global_pred
+        result = torch.round(global_pred * self.latent_token_rate)
         # avoid error accumulation for each frame
         if use_local:
-            pred_from_local = torch.round(local_pred * self.latent_token_rate)
-            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
+            pred_from_local = local_pred.sum(1)
             result[is_time_aligned] = pred_from_local[is_time_aligned]

         return result
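Why this avoids drift: the local predictions are now already integer token counts, so summing them gives the exact total, instead of rounding per token, converting back to seconds, and re-rounding later. Illustrative numbers:

import torch

local_pred = torch.tensor([[7., 12., 5.]])   # token counts from prepare_local_duration
pred_from_local = local_pred.sum(1)          # tensor([24.]) -- exact, no accumulation error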
@@ -503,20 +504,18 @@ class DurationAdapterMixin:
         x: torch.Tensor,
         content_mask: torch.Tensor,
         local_duration: torch.Tensor,
-        global_duration: torch.Tensor | None = None,
+        global_duration: torch.Tensor,
     ):
-        n_latents = torch.round(local_duration * self.latent_token_rate)
-        if global_duration is not None:
-            latent_length = torch.round(
-                global_duration * self.latent_token_rate
-            )
-        else:
-            latent_length = n_latents.sum(1)
+        training = getattr(self, 'training', False)
+        if not training:  # inference mode
+            latent_length = global_duration
+        else:  # training mode
+            latent_length = local_duration.sum(1)
         latent_mask = create_mask_from_length(latent_length).to(
             content_mask.device
         )
         attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
-        align_path = create_alignment_path(n_latents, attn_mask)
+        align_path = create_alignment_path(local_duration, attn_mask)
         expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
         return expanded_x, latent_mask
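What the alignment-path matmul computes, sketched under the assumption that create_alignment_path builds a monotonic 0/1 matrix with local_duration[i] ones per content position; for a single unpadded sequence the expansion reduces to repeat_interleave:

import torch

x = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])  # (content_len=3, dim=2)
local_duration = torch.tensor([2, 1, 3])          # latent tokens per content step
expanded = torch.repeat_interleave(x, local_duration, dim=0)
print(expanded.shape)                             # torch.Size([6, 2])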
 
@@ -665,14 +664,12 @@ class CrossAttentionAudioFlowMatching(
         )

         # prepare global duration
-        global_duration = self.prepare_global_duration(
+        latent_length = self.prepare_global_duration(
             global_duration_pred,
             local_duration_pred,
             is_time_aligned,
             use_local=False
         )
-        # TODO: manually set duration for SE and AudioSR
-        latent_length = torch.round(global_duration * self.latent_token_rate)
         task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
         latent_length[task_mask] = content[task_mask].size(1)
         latent_mask = create_mask_from_length(latent_length).to(device)
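The same-length override kept above, sketched with assumed task names (the real SAME_LENGTH_TASKS set and content tensor come from the repo): tasks whose output must match the input length keep the content's latent length instead of the prediction.

import torch

SAME_LENGTH_TASKS = {"se", "audiosr"}      # assumed membership
task = ["tts", "se"]
latent_length = torch.tensor([40, 37])     # predicted latent lengths
task_mask = torch.as_tensor([t in SAME_LENGTH_TASKS for t in task])
latent_length[task_mask] = 50              # stands in for content[task_mask].size(1)
print(latent_length)                       # tensor([40, 50])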
@@ -735,7 +732,8 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         duration_offset: float = 1.0,
         cfg_drop_ratio: float = 0.2,
         sample_strategy: str = 'normal',
-        num_train_steps: int = 1000
+        num_train_steps: int = 1000,
+        task_weights: dict | None = None,
     ):

         super().__init__(
@@ -758,6 +756,7 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         )
         self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
         self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))
+        self.task_weights = task_weights

     def get_backbone_input(
         self, target_length: int, content: torch.Tensor,
@@ -808,7 +807,12 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         **kwargs
     ):
         device = self.dummy_param.device
-        loss_reduce = self.training or (loss_reduce and not self.training)
+        if self.training:
+            if self.task_weights:
+                loss_reduce = False
+            else:
+                loss_reduce = True
+        # loss_reduce = self.training or (loss_reduce and not self.training)

         self.autoencoder.eval()
         with torch.no_grad():
@@ -859,10 +863,12 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         duration = F.pad(
             duration, (0, content_mask.size(1) - duration.size(1))
         )
+        local_latent_duration = torch.round(duration * self.latent_token_rate)
         time_aligned_content, _ = self.expand_by_duration(
             x=content[:, :trunc_ta_length],
             content_mask=ta_content_mask,
-            local_duration=duration,
+            local_duration=local_latent_duration,
+            global_duration=latent_mask.sum(1),
         )

         # --------------------------------------------------------------------
@@ -899,6 +905,16 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         target = target.transpose(1, self.autoencoder.time_dim)
         diff_loss = F.mse_loss(pred, target, reduction="none")
         diff_loss = loss_with_mask(diff_loss, latent_mask, reduce=loss_reduce)
+
+        if self.training and self.task_weights:
+            loss_weights = torch.tensor([self.task_weights[t] for t in task],
+                                        device=device)
+            diff_loss = (diff_loss * loss_weights).sum() / loss_weights.sum()
+            local_duration_loss = (local_duration_loss *
+                                   loss_weights).sum() / loss_weights.sum()
+            global_duration_loss = (global_duration_loss *
+                                    loss_weights).sum() / loss_weights.sum()
+
         return {
             "diff_loss": diff_loss,
             "local_duration_loss": local_duration_loss,
@@ -939,17 +955,21 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
         trunc_ta_length = content.size(1)

         # prepare local duration
-        local_duration = self.prepare_local_duration(
+        local_latent_duration = self.prepare_local_duration(
             local_duration_pred, content_mask
         )
-        local_duration = local_duration[:, :trunc_ta_length]
+        local_latent_duration = local_latent_duration[:, :trunc_ta_length]
         # use ground truth duration
         if use_gt_duration and "duration" in kwargs:
-            local_duration = torch.as_tensor(kwargs["duration"]).to(device)
+            local_latent_duration = torch.as_tensor(kwargs["duration"]
+                                                    ).to(device)
+            local_latent_duration = torch.round(
+                local_latent_duration * self.latent_token_rate
+            )

         # prepare global duration
-        global_duration = self.prepare_global_duration(
-            global_duration_pred, local_duration, is_time_aligned
+        global_latent_duration = self.prepare_global_duration(
+            global_duration_pred, local_latent_duration, is_time_aligned
         )
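With the unified scale, ground-truth durations passed in seconds via kwargs now need the same seconds-to-token rounding as predictions; illustrative values:

import torch

latent_token_rate = 25.0                         # illustrative
gt_seconds = torch.tensor([[0.12, 0.08, 0.20]])  # stands in for kwargs["duration"]
gt_tokens = torch.round(gt_seconds * latent_token_rate)
print(gt_tokens)                                 # tensor([[3., 2., 5.]])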
  # --------------------------------------------------------------------
@@ -958,8 +978,8 @@ class DummyContentAudioFlowMatching(CrossAttentionAudioFlowMatching):
958
  time_aligned_content, latent_mask = self.expand_by_duration(
959
  x=content[:, :trunc_ta_length],
960
  content_mask=content_mask[:, :trunc_ta_length],
961
- local_duration=local_duration,
962
- global_duration=global_duration,
963
  )
964
 
965
  context, context_mask, time_aligned_content = self.get_backbone_input(
 