keylxiao committed on
Commit
5ab4485
·
1 Parent(s): 745e016

feat :sparkles: : add lyrics alignment scores

Browse files
acestep/dit_alignment_score.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import numpy as np
12
  import torch.nn.functional as F
13
  from dataclasses import dataclass, asdict
14
- from typing import List, Dict, Any, Optional
15
 
16
 
17
  # ================= Data Classes =================
@@ -545,3 +545,326 @@ class MusicStampsAligner:
545
  "lrc_text": lrc_text
546
  }
547
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  import numpy as np
12
  import torch.nn.functional as F
13
  from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
 
16
 
17
  # ================= Data Classes =================
 
545
  "lrc_text": lrc_text
546
  }
547
 
548
+
549
class MusicLyricScorer:
    """Scorer for evaluating lyrics-to-audio alignment quality.

    Computes alignment quality metrics (Coverage, Monotonicity, Confidence)
    from cross-attention matrices using tensor operations, so the math can
    run on GPU and remain potentially differentiable.
    """

    def __init__(self, tokenizer: Any):
        """Initialize the scorer.

        Args:
            tokenizer: Tokenizer instance (must implement .decode()).
        """
        self.tokenizer = tokenizer

    def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
        """Generate a mask distinguishing lyrics (1) from structural tags (0).

        Tokens inside square brackets (e.g. "[verse]") are treated as
        structural tags. Uses self.tokenizer to decode tokens one at a time.

        Args:
            token_ids: List of token IDs.

        Returns:
            Numpy int32 array of shape [len(token_ids)] with 1 (lyrics) or 0 (tag).
        """
        decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        mask = np.ones(len(token_ids), dtype=np.int32)
        in_bracket = False

        for i, token_str in enumerate(decoded_tokens):
            if '[' in token_str:
                in_bracket = True
            if in_bracket:
                mask[i] = 0
            if ']' in token_str:
                in_bracket = False
                # The token carrying the closing bracket is itself a tag token.
                mask[i] = 0
        return mask

    def _preprocess_attention(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        custom_config: Dict[int, List[int]],
        medfilt_width: int = 1
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
        """Extract and normalize the attention matrix.

        Logic V4: uses Min-Max normalization to highlight energy differences.

        Args:
            attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
            custom_config: Config mapping layer indices to lists of head indices.
            medfilt_width: Width for median filtering.

        Returns:
            Tuple of (calc_matrix, energy_matrix, avg_weights_tensor); all three
            are None when the config selects no valid heads.
        """
        # 1. Prepare tensor on CPU in float32 (clone to avoid mutating the input).
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()

        # 2. Select heads based on config; out-of-range entries are skipped silently.
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])

        if not selected_tensors:
            return None, None, None

        weights_stack = torch.stack(selected_tensors, dim=0)

        # 3. Average selected heads -> [Tokens, Frames]
        avg_weights = weights_stack.mean(dim=0)

        # 4. Median filter, then Min-Max normalize to [0, 1].
        # `median_filter` is a module-level helper defined elsewhere in this file.
        energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
        energy_matrix = energy_tensor.numpy()

        e_min, e_max = energy_matrix.min(), energy_matrix.max()

        if e_max - e_min > 1e-9:
            energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
        else:
            # Constant matrix: no usable signal, zero it out.
            energy_matrix = np.zeros_like(energy_matrix)

        # Contrast enhancement for DTW pathfinding:
        # calc_matrix drives pathfinding, energy_matrix drives scoring.
        calc_matrix = energy_matrix ** 2

        return calc_matrix, energy_matrix, avg_weights

    def _compute_alignment_metrics(
        self,
        energy_matrix: torch.Tensor,
        path_coords: torch.Tensor,
        type_mask: torch.Tensor,
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Tuple[float, float, float]:
        """Core metric calculation using high-precision tensor operations.

        Args:
            energy_matrix: Normalized energy [Rows, Cols].
            path_coords: DTW path coordinates [Steps, 2] as (row, col) pairs.
            type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward overlap for the monotonicity check.
            instrumental_weight: Weight for non-lyric tokens in confidence calc.

        Returns:
            Tuple of (coverage, monotonicity, confidence) as Python floats.
        """
        # Ensure high precision for internal calculation.
        energy_matrix = energy_matrix.to(dtype=torch.float64)
        path_coords = path_coords.long()
        type_mask = type_mask.long()

        device = energy_matrix.device
        rows, cols = energy_matrix.shape

        is_lyrics_row = (type_mask == 1)

        # ================= A. Coverage Score =================
        # Ratio of lyric rows that have a significant energy peak.
        row_max_energies = energy_matrix.max(dim=1).values
        total_sung_rows = is_lyrics_row.sum().double()

        coverage_threshold = 0.1
        valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
        valid_sung_rows = valid_sung_mask.sum().double()

        if total_sung_rows > 0:
            coverage_score = valid_sung_rows / total_sung_rows
        else:
            # No lyric rows at all: vacuously covered.
            coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= B. Monotonicity Score =================
        # Check whether the "center of mass" of lyric rows moves forward in time.
        col_indices = torch.arange(cols, device=device, dtype=torch.float64)

        # Zero out low-energy noise below the time_weight threshold.
        weights = torch.where(
            energy_matrix > time_weight,
            energy_matrix,
            torch.zeros_like(energy_matrix)
        )

        sum_w = weights.sum(dim=1)
        sum_t = (weights * col_indices).sum(dim=1)

        # Calculate per-row centroids; -1 marks rows with no usable energy.
        centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
        valid_w_mask = sum_w > 1e-9
        centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]

        # Extract the sequence of valid lyric centroids.
        valid_sequence_mask = is_lyrics_row & (centroids >= 0)
        sung_centroids = centroids[valid_sequence_mask]

        cnt = sung_centroids.shape[0]
        if cnt > 1:
            curr_c = sung_centroids[:-1]
            next_c = sung_centroids[1:]

            # Check non-decreasing order with overlap tolerance.
            non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
            pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
            monotonicity_score = non_decreasing / pairs
        else:
            # Fewer than two centroids: trivially monotonic.
            monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= C. Path Confidence =================
        # Weighted average energy along the optimal DTW path.
        if path_coords.shape[0] > 0:
            p_rows = path_coords[:, 0]
            p_cols = path_coords[:, 1]

            path_energies = energy_matrix[p_rows, p_cols]
            step_weights = torch.ones_like(path_energies)

            # Lower weight for instrumental/tag steps.
            is_inst_step = (type_mask[p_rows] == 0)
            step_weights[is_inst_step] = instrumental_weight

            total_energy = (path_energies * step_weights).sum()
            total_steps = step_weights.sum()

            if total_steps > 0:
                path_confidence = total_energy / total_steps
            else:
                path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        else:
            path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)

        return coverage_score.item(), monotonicity_score.item(), path_confidence.item()

    def lyrics_alignment_info(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        token_ids: List[int],
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """Generate the DTW alignment path and processed matrices.

        Args:
            attention_matrix: Input attention tensor.
            token_ids: Corresponding token IDs.
            custom_config: Layer/Head configuration.
            return_matrices: If True, also returns calc/vis matrices in the output.
            medfilt_width: Median filter width.

        Returns:
            Dict with "path_coords", "type_mask" and "energy_matrix" on success,
            or {"calc_matrix": None, "error": ...} when no valid heads were found.
        """
        calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
            attention_matrix, custom_config, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "error": "No valid attention heads found"
            }

        # 1. Generate semantic mask (1=Lyrics, 0=Tags); uses self.tokenizer internally.
        type_mask = self._generate_token_type_mask(token_ids)

        # Safety check for shape mismatch.
        if len(type_mask) != energy_matrix.shape[0]:
            # Fallback to all lyrics if shapes don't align.
            type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)

        # 2. DTW pathfinding. Negative calc_matrix because DTW minimizes cost.
        # `dtw_cpu` is a module-level helper defined elsewhere in this file.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
        path_coords = np.stack([text_indices, time_indices], axis=1)

        return_dict = {
            "path_coords": path_coords,
            "type_mask": type_mask,
            "energy_matrix": energy_matrix
        }
        if return_matrices:
            return_dict['calc_matrix'] = calc_matrix
            return_dict['vis_matrix'] = vis_matrix

        return return_dict

    def calculate_score(
        self,
        energy_matrix: Union[torch.Tensor, np.ndarray],
        type_mask: Union[torch.Tensor, np.ndarray],
        path_coords: Union[torch.Tensor, np.ndarray],
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Dict[str, Any]:
        """Calculate the final alignment score from pre-computed components.

        Args:
            energy_matrix: Processed energy matrix.
            type_mask: Token type mask.
            path_coords: DTW path coordinates.
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward movement in frames.
            instrumental_weight: Weight for non-lyric path steps.

        Returns:
            Dict with "lyrics_score": final score rounded to 4 decimals,
            computed as Coverage^2 * Monotonicity^2 * Confidence, clipped to [0, 1].
        """
        # Ensure inputs are tensors on a consistent device.
        if not isinstance(energy_matrix, torch.Tensor):
            # FIX: do not hard-code 'cuda' here — fall back to CPU when CUDA
            # is unavailable so scoring works on CPU-only machines.
            target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            energy_matrix = torch.tensor(
                energy_matrix, device=target_device, dtype=torch.float32
            )

        device = energy_matrix.device

        if not isinstance(type_mask, torch.Tensor):
            type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
        else:
            type_mask = type_mask.to(device=device, dtype=torch.long)

        if not isinstance(path_coords, torch.Tensor):
            path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
        else:
            path_coords = path_coords.to(device=device, dtype=torch.long)

        # Compute metrics.
        coverage, monotonicity, confidence = self._compute_alignment_metrics(
            energy_matrix=energy_matrix,
            path_coords=path_coords,
            type_mask=type_mask,
            time_weight=time_weight,
            overlap_frames=overlap_frames,
            instrumental_weight=instrumental_weight
        )

        # Final score: (Cov^2 * Mono^2 * Conf), clipped to [0, 1].
        final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
        final_score = float(np.clip(final_score, 0.0, 1.0))

        return {
            "lyrics_score": round(final_score, 4)
        }
acestep/gradio_ui/events/__init__.py CHANGED
@@ -336,7 +336,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
336
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
337
  def make_score_handler(idx):
338
  return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
339
- llm_handler, idx, scale, batch_idx, queue
340
  )
341
 
342
  for btn_idx in range(1, 9):
 
336
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
337
  def make_score_handler(idx):
338
  return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
339
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
340
  )
341
 
342
  for btn_idx in range(1, 9):
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -714,7 +714,22 @@ def generate_with_progress(
714
 
715
 
716
 
717
- def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
718
  """
719
  Calculate PMI-based quality score for generated audio.
720
 
@@ -733,6 +748,9 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
733
  audio_duration: Audio duration value
734
  vocal_language: Vocal language value
735
  score_scale: Sensitivity scale parameter
 
 
 
736
 
737
  Returns:
738
  Score display string
@@ -791,7 +809,37 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
791
  topk=10,
792
  score_scale=score_scale
793
  )
794
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
795
  # Format display string with per-condition breakdown
796
  if global_score == 0.0 and not scores_per_condition:
797
  return t("messages.score_failed", error=status)
@@ -804,12 +852,17 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
804
  )
805
 
806
  conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
807
-
808
- return (
809
  f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
810
- f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n\n"
811
- f"Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI\n"
812
  )
 
 
 
 
 
 
813
 
814
  except Exception as e:
815
  import traceback
@@ -817,12 +870,19 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
817
  return error_msg
818
 
819
 
820
- def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale, current_batch_index, batch_queue):
 
 
 
 
 
 
821
  """
822
  Calculate PMI-based quality score - REFACTORED to read from batch_queue only.
823
  This ensures scoring uses the actual generation parameters, not current UI values.
824
 
825
  Args:
 
826
  llm_handler: LLM handler instance
827
  sample_idx: Which sample to score (1-8)
828
  score_scale: Sensitivity scale parameter (tool setting, can be from UI)
@@ -843,6 +903,7 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
843
  time_signature = params.get("time_signature", "")
844
  audio_duration = params.get("audio_duration", -1)
845
  vocal_language = params.get("vocal_language", "")
 
846
 
847
  # Get LM metadata from batch_data (if it was saved during generation)
848
  lm_metadata = batch_data.get("lm_generated_metadata", None)
@@ -862,13 +923,51 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
862
  else:
863
  # Single mode: all samples use same codes
864
  audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
865
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866
  # Calculate score using historical parameters
867
  score_display = calculate_score_handler(
868
  llm_handler,
869
  audio_codes_str, caption, lyrics, lm_metadata,
870
  bpm, key_scale, time_signature, audio_duration, vocal_language,
871
- score_scale
 
 
 
872
  )
873
 
874
  # Update batch_queue with the calculated score
 
714
 
715
 
716
 
717
+ def calculate_score_handler(
718
+ llm_handler,
719
+ audio_codes_str,
720
+ caption,
721
+ lyrics,
722
+ lm_metadata,
723
+ bpm,
724
+ key_scale,
725
+ time_signature,
726
+ audio_duration,
727
+ vocal_language,
728
+ score_scale,
729
+ dit_handler,
730
+ extra_tensor_data,
731
+ inference_steps,
732
+ ):
733
  """
734
  Calculate PMI-based quality score for generated audio.
735
 
 
748
  audio_duration: Audio duration value
749
  vocal_language: Vocal language value
750
  score_scale: Sensitivity scale parameter
751
+ dit_handler: DiT handler instance (for alignment scoring)
752
+ extra_tensor_data: Dictionary containing tensors for the specific sample
753
+ inference_steps: Number of inference steps used
754
 
755
  Returns:
756
  Score display string
 
809
  topk=10,
810
  score_scale=score_scale
811
  )
812
+
813
+ alignment_report = ""
814
+
815
+ # Only calculate if we have the handler, tensor data, and actual lyrics
816
+ if dit_handler and extra_tensor_data and lyrics and lyrics.strip():
817
+ try:
818
+ align_result = dit_handler.get_lyric_score(
819
+ pred_latent=extra_tensor_data.get('pred_latent'),
820
+ encoder_hidden_states=extra_tensor_data.get('encoder_hidden_states'),
821
+ encoder_attention_mask=extra_tensor_data.get('encoder_attention_mask'),
822
+ context_latents=extra_tensor_data.get('context_latents'),
823
+ lyric_token_ids=extra_tensor_data.get('lyric_token_ids'),
824
+ vocal_language=vocal_language or "en",
825
+ inference_steps=int(inference_steps),
826
+ seed=42,
827
+ )
828
+
829
+ if align_result.get("success"):
830
+ lm_align_score = align_result.get("lm_score", 0.0)
831
+ dit_align_score = align_result.get("dit_score", 0.0)
832
+ alignment_report = (
833
+ f" • llm lyrics alignment score: {lm_align_score:.4f}\n"
834
+ f" • dit lyrics alignment score: {dit_align_score:.4f}\n"
835
+ "\n(Measures how well lyrics timestamps match audio energy using Cross-Attention)"
836
+ )
837
+ else:
838
+ align_err = align_result.get("error", "Unknown error")
839
+ alignment_report = f"\n⚠️ Alignment Score Failed: {align_err}"
840
+ except Exception as e:
841
+ alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
842
+
843
  # Format display string with per-condition breakdown
844
  if global_score == 0.0 and not scores_per_condition:
845
  return t("messages.score_failed", error=status)
 
852
  )
853
 
854
  conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
855
+
856
+ final_output = (
857
  f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
858
+ f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n"
 
859
  )
860
+
861
+ if alignment_report:
862
+ final_output += alignment_report + "\n"
863
+
864
+ final_output += "Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI"
865
+ return final_output
866
 
867
  except Exception as e:
868
  import traceback
 
870
  return error_msg
871
 
872
 
873
+ def calculate_score_handler_with_selection(
874
+ dit_handler,
875
+ llm_handler,
876
+ sample_idx,
877
+ score_scale,
878
+ current_batch_index,
879
+ batch_queue):
880
  """
881
  Calculate PMI-based quality score - REFACTORED to read from batch_queue only.
882
  This ensures scoring uses the actual generation parameters, not current UI values.
883
 
884
  Args:
885
+ dit_handler: DiT Handler
886
  llm_handler: LLM handler instance
887
  sample_idx: Which sample to score (1-8)
888
  score_scale: Sensitivity scale parameter (tool setting, can be from UI)
 
903
  time_signature = params.get("time_signature", "")
904
  audio_duration = params.get("audio_duration", -1)
905
  vocal_language = params.get("vocal_language", "")
906
+ inference_steps = params.get("inference_steps", 8)
907
 
908
  # Get LM metadata from batch_data (if it was saved during generation)
909
  lm_metadata = batch_data.get("lm_generated_metadata", None)
 
923
  else:
924
  # Single mode: all samples use same codes
925
  audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
926
+
927
+ # Extract Tensor Data for Alignment Score (Extra Outputs)
928
+ extra_tensor_data = None
929
+ extra_outputs = batch_data.get("extra_outputs", {})
930
+
931
+ # Only proceed if we have tensors and a valid index
932
+ if extra_outputs and dit_handler:
933
+ pred_latents = extra_outputs.get("pred_latents")
934
+ # Ensure we have the critical tensor to check batch size
935
+ if pred_latents is not None:
936
+ sample_idx_0based = sample_idx - 1
937
+ batch_size = pred_latents.shape[0]
938
+
939
+ if 0 <= sample_idx_0based < batch_size:
940
+ # Slice tensors for this specific sample (keep dimension [1, ...])
941
+ # We assume all stored tensors are aligned in batch dim 0
942
+ try:
943
+ extra_tensor_data = {
944
+ "pred_latent": pred_latents[sample_idx_0based:sample_idx_0based + 1],
945
+ "encoder_hidden_states": extra_outputs.get("encoder_hidden_states")[
946
+ sample_idx_0based:sample_idx_0based + 1],
947
+ "encoder_attention_mask": extra_outputs.get("encoder_attention_mask")[
948
+ sample_idx_0based:sample_idx_0based + 1],
949
+ "context_latents": extra_outputs.get("context_latents")[
950
+ sample_idx_0based:sample_idx_0based + 1],
951
+ "lyric_token_ids": extra_outputs.get("lyric_token_idss")[
952
+ sample_idx_0based:sample_idx_0based + 1]
953
+ }
954
+
955
+ # Verify no None values in the sliced dict
956
+ if any(v is None for v in extra_tensor_data.values()):
957
+ extra_tensor_data = None
958
+ except Exception as e:
959
+ print(f"Error slicing tensor data for score: {e}")
960
+ extra_tensor_data = None
961
+
962
  # Calculate score using historical parameters
963
  score_display = calculate_score_handler(
964
  llm_handler,
965
  audio_codes_str, caption, lyrics, lm_metadata,
966
  bpm, key_scale, time_signature, audio_duration, vocal_language,
967
+ score_scale,
968
+ dit_handler,
969
+ extra_tensor_data,
970
+ inference_steps,
971
  )
972
 
973
  # Update batch_queue with the calculated score
acestep/handler.py CHANGED
@@ -31,7 +31,7 @@ from acestep.constants import (
31
  SFT_GEN_PROMPT,
32
  DEFAULT_DIT_INSTRUCTION,
33
  )
34
- from acestep.dit_alignment_score import MusicStampsAligner
35
 
36
 
37
  warnings.filterwarnings("ignore")
@@ -2553,3 +2553,229 @@ class AceStepHandler:
2553
  "success": False,
2554
  "error": error_msg
2555
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  SFT_GEN_PROMPT,
32
  DEFAULT_DIT_INSTRUCTION,
33
  )
34
+ from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer
35
 
36
 
37
  warnings.filterwarnings("ignore")
 
2553
  "success": False,
2554
  "error": error_msg
2555
  }
2556
+
2557
+ @torch.no_grad()
2558
+ def get_lyric_score(
2559
+ self,
2560
+ pred_latent: torch.Tensor,
2561
+ encoder_hidden_states: torch.Tensor,
2562
+ encoder_attention_mask: torch.Tensor,
2563
+ context_latents: torch.Tensor,
2564
+ lyric_token_ids: torch.Tensor,
2565
+ vocal_language: str = "en",
2566
+ inference_steps: int = 8,
2567
+ seed: int = 42,
2568
+ custom_layers_config: Optional[Dict] = None,
2569
+ ) -> Dict[str, Any]:
2570
+ """
2571
+ Calculate both LM and DiT alignment scores in one pass.
2572
+
2573
+ - lm_score: Checks structural alignment using pure noise at t=1.0.
2574
+ - dit_score: Checks denoising alignment using regressed latents at t=1/steps.
2575
+
2576
+ Args:
2577
+ pred_latent: Generated latent tensor [batch, T, D]
2578
+ encoder_hidden_states: Cached encoder hidden states
2579
+ encoder_attention_mask: Cached encoder attention mask
2580
+ context_latents: Cached context latents
2581
+ lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
2582
+ vocal_language: Language code for lyrics header parsing
2583
+ inference_steps: Number of inference steps (for noise level calculation)
2584
+ seed: Random seed for noise generation
2585
+ custom_layers_config: Dict mapping layer indices to head indices
2586
+
2587
+ Returns:
2588
+ Dict containing:
2589
+ - lm_score: float
2590
+ - dit_score: float
2591
+ - success: Whether generation succeeded
2592
+ - error: Error message if failed
2593
+ """
2594
+ from transformers.cache_utils import EncoderDecoderCache, DynamicCache
2595
+
2596
+ if self.model is None:
2597
+ return {
2598
+ "lm_score": 0.0,
2599
+ "dit_score": 0.0,
2600
+ "success": False,
2601
+ "error": "Model not initialized"
2602
+ }
2603
+
2604
+ if custom_layers_config is None:
2605
+ custom_layers_config = self.custom_layers_config
2606
+
2607
+ try:
2608
+ # Move tensors to device
2609
+ device = self.device
2610
+ dtype = self.dtype
2611
+
2612
+ pred_latent = pred_latent.to(device=device, dtype=dtype)
2613
+ encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
2614
+ encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
2615
+ context_latents = context_latents.to(device=device, dtype=dtype)
2616
+
2617
+ bsz = pred_latent.shape[0]
2618
+
2619
+ if seed is None:
2620
+ x0 = torch.randn_like(pred_latent)
2621
+ else:
2622
+ generator = torch.Generator(device=device).manual_seed(int(seed))
2623
+ x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype)
2624
+
2625
+ # --- Input A: LM Score ---
2626
+ # t = 1.0, xt = Pure Noise
2627
+ t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype)
2628
+ xt_lm = x0
2629
+
2630
+ # --- Input B: DiT Score ---
2631
+ # t = 1.0/steps, xt = Regressed Latent
2632
+ t_last_val = 1.0 / inference_steps
2633
+ t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
2634
+ # Flow Matching Regression: xt = t*x0 + (1-t)*x1
2635
+ xt_dit = t_last_val * x0 + (1.0 - t_last_val) * pred_latent
2636
+
2637
+ # Order: [Think_Batch, DiT_Batch]
2638
+ xt_in = torch.cat([xt_lm, xt_dit], dim=0)
2639
+ t_in = torch.cat([t_lm, t_dit], dim=0)
2640
+
2641
+ # Duplicate conditions
2642
+ encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
2643
+ encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
2644
+ context_latents_in = torch.cat([context_latents, context_latents], dim=0)
2645
+
2646
+ # Prepare Attention Mask
2647
+ latent_length = xt_in.shape[1]
2648
+ attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype)
2649
+ past_key_values = None
2650
+
2651
+ # Run decoder with output_attentions=True
2652
+ with self._load_model_context("model"):
2653
+ decoder = self.model.decoder
2654
+ if hasattr(decoder, 'eval'):
2655
+ decoder.eval()
2656
+
2657
+ decoder_outputs = decoder(
2658
+ hidden_states=xt_in,
2659
+ timestep=t_in,
2660
+ timestep_r=t_in,
2661
+ attention_mask=attention_mask_in,
2662
+ encoder_hidden_states=encoder_hidden_states_in,
2663
+ use_cache=False,
2664
+ past_key_values=past_key_values,
2665
+ encoder_attention_mask=encoder_attention_mask_in,
2666
+ context_latents=context_latents_in,
2667
+ output_attentions=True,
2668
+ custom_layers_config=custom_layers_config,
2669
+ enable_early_exit=True
2670
+ )
2671
+
2672
+ # Extract cross-attention matrices
2673
+ if decoder_outputs[2] is None:
2674
+ return {
2675
+ "lm_score": 0.0,
2676
+ "dit_score": 0.0,
2677
+ "success": False,
2678
+ "error": "Model did not return attentions"
2679
+ }
2680
+
2681
+ cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None)
2682
+
2683
+ captured_layers_list = []
2684
+ for layer_attn in cross_attns:
2685
+ if layer_attn is None:
2686
+ continue
2687
+
2688
+ # Only take conditional part (first half of batch)
2689
+ layer_matrix = layer_attn.transpose(-1, -2)
2690
+ captured_layers_list.append(layer_matrix)
2691
+
2692
+ if not captured_layers_list:
2693
+ return {
2694
+ "lm_score": 0.0,
2695
+ "dit_score": 0.0,
2696
+ "success": False,
2697
+ "error": "No valid attention layers returned"
2698
+ }
2699
+
2700
+ stacked = torch.stack(captured_layers_list)
2701
+
2702
+ all_layers_matrix_lm = stacked[:, :bsz, ...]
2703
+ all_layers_matrix_dit = stacked[:, bsz:, ...]
2704
+
2705
+ if bsz == 1:
2706
+ all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1)
2707
+ all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1)
2708
+ else:
2709
+ pass
2710
+
2711
+ # Process lyric token IDs to extract pure lyrics
2712
+ if isinstance(lyric_token_ids, torch.Tensor):
2713
+ raw_lyric_ids = lyric_token_ids[0].tolist()
2714
+ else:
2715
+ raw_lyric_ids = lyric_token_ids
2716
+
2717
+ # Parse header to find lyrics start position
2718
+ header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
2719
+ header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
2720
+ start_idx = len(header_ids)
2721
+
2722
+ # Find end of lyrics (before endoftext token)
2723
+ try:
2724
+ end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token
2725
+ except ValueError:
2726
+ end_idx = len(raw_lyric_ids)
2727
+
2728
+ pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
2729
+ if start_idx >= all_layers_matrix_lm.shape[-2]: # Check text dim
2730
+ return {
2731
+ "lm_score": 0.0,
2732
+ "dit_score": 0.0,
2733
+ "success": False,
2734
+ "error": "Lyrics indices out of bounds"
2735
+ }
2736
+
2737
+ pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :]
2738
+ pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :]
2739
+
2740
+ # Create aligner and calculate alignment info
2741
+ aligner = MusicLyricScorer(self.text_tokenizer)
2742
+
2743
+ def calculate_single_score(matrix):
2744
+ """Helper to run aligner on a matrix"""
2745
+ info = aligner.lyrics_alignment_info(
2746
+ attention_matrix=matrix,
2747
+ token_ids=pure_lyric_ids,
2748
+ custom_config=custom_layers_config,
2749
+ return_matrices=False,
2750
+ medfilt_width=1,
2751
+ )
2752
+ if info.get("energy_matrix") is None:
2753
+ return 0.0
2754
+
2755
+ res = aligner.calculate_score(
2756
+ energy_matrix=info["energy_matrix"],
2757
+ type_mask=info["type_mask"],
2758
+ path_coords=info["path_coords"],
2759
+ )
2760
+ # Return the final score (check return key)
2761
+ return res.get("lyrics_score", res.get("final_score", 0.0))
2762
+
2763
+ lm_score = calculate_single_score(pure_matrix_lm)
2764
+ dit_score = calculate_single_score(pure_matrix_dit)
2765
+
2766
+ return {
2767
+ "lm_score": lm_score,
2768
+ "dit_score": dit_score,
2769
+ "success": True,
2770
+ "error": None
2771
+ }
2772
+
2773
+ except Exception as e:
2774
+ error_msg = f"Error generating score: {str(e)}"
2775
+ logger.exception("[get_lyric_score] Failed")
2776
+ return {
2777
+ "lm_score": 0.0,
2778
+ "dit_score": 0.0,
2779
+ "success": False,
2780
+ "error": error_msg
2781
+ }