Spaces:
Running
on
A100
Running
on
A100
feat :sparkles: : add lyrics alignment scores
Browse files- acestep/dit_alignment_score.py +324 -1
- acestep/gradio_ui/events/__init__.py +1 -1
- acestep/gradio_ui/events/results_handlers.py +108 -9
- acestep/handler.py +227 -1
acestep/dit_alignment_score.py
CHANGED
|
@@ -11,7 +11,7 @@ import torch
|
|
| 11 |
import numpy as np
|
| 12 |
import torch.nn.functional as F
|
| 13 |
from dataclasses import dataclass, asdict
|
| 14 |
-
from typing import List, Dict, Any, Optional
|
| 15 |
|
| 16 |
|
| 17 |
# ================= Data Classes =================
|
|
@@ -545,3 +545,326 @@ class MusicStampsAligner:
|
|
| 545 |
"lrc_text": lrc_text
|
| 546 |
}
|
| 547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import torch.nn.functional as F
|
| 13 |
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 15 |
|
| 16 |
|
| 17 |
# ================= Data Classes =================
|
|
|
|
| 545 |
"lrc_text": lrc_text
|
| 546 |
}
|
| 547 |
|
| 548 |
+
|
| 549 |
+
class MusicLyricScorer:
    """
    Scorer class for evaluating lyrics-to-audio alignment quality.

    Focuses on calculating alignment quality metrics (Coverage, Monotonicity,
    Confidence) using tensor operations for potential differentiability or GPU
    acceleration.
    """

    def __init__(self, tokenizer: Any):
        """
        Initialize the aligner.

        Args:
            tokenizer: Tokenizer instance (must implement .decode()).
        """
        self.tokenizer = tokenizer

    def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
        """
        Generate a mask distinguishing lyrics (1) from structural tags (0).
        Uses self.tokenizer to decode tokens; anything between '[' and ']'
        (inclusive) is treated as a structural tag.

        Args:
            token_ids: List of token IDs.

        Returns:
            Numpy int32 array of shape [len(token_ids)] with 1 or 0.
        """
        decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        mask = np.ones(len(token_ids), dtype=np.int32)
        in_bracket = False

        for i, token_str in enumerate(decoded_tokens):
            if '[' in token_str:
                in_bracket = True
            if in_bracket:
                mask[i] = 0
            if ']' in token_str:
                in_bracket = False
                # Closing-bracket token itself is also part of the tag.
                mask[i] = 0
        return mask

    def _preprocess_attention(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        custom_config: Dict[int, List[int]],
        medfilt_width: int = 1
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
        """
        Extracts and normalizes the attention matrix.

        Logic V4: Uses Min-Max normalization to highlight energy differences.

        Args:
            attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
            custom_config: Config mapping layer indices to lists of head indices.
            medfilt_width: Width for median filtering.

        Returns:
            Tuple of (calc_matrix, energy_matrix, avg_weights_tensor), or
            (None, None, None) if no configured layer/head exists in the input.
        """
        # 1. Prepare Tensor (work on a CPU float copy; never mutate the input)
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()

        # 2. Select Heads based on config (silently skip out-of-range entries)
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])

        if not selected_tensors:
            return None, None, None

        weights_stack = torch.stack(selected_tensors, dim=0)

        # 3. Average Heads
        avg_weights = weights_stack.mean(dim=0)  # [Tokens, Frames]

        # 4. Preprocessing Logic
        # Median filter is applied first, then Min-Max normalization preserving
        # the overall energy distribution.
        # NOTE(review): `median_filter` is defined elsewhere in this module.
        energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
        energy_matrix = energy_tensor.numpy()

        e_min, e_max = energy_matrix.min(), energy_matrix.max()

        if e_max - e_min > 1e-9:
            energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
        else:
            # Degenerate (constant) matrix: no usable energy signal.
            energy_matrix = np.zeros_like(energy_matrix)

        # Contrast enhancement for DTW pathfinding:
        # calc_matrix is used for pathfinding, energy_matrix for scoring.
        calc_matrix = energy_matrix ** 2

        return calc_matrix, energy_matrix, avg_weights

    def _compute_alignment_metrics(
        self,
        energy_matrix: torch.Tensor,
        path_coords: torch.Tensor,
        type_mask: torch.Tensor,
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Tuple[float, float, float]:
        """
        Core metric calculation logic using high-precision Tensor operations.

        Args:
            energy_matrix: Normalized energy [Rows, Cols].
            path_coords: DTW path coordinates [Steps, 2].
            type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed overlap for monotonicity check.
            instrumental_weight: Weight for non-lyric tokens in confidence calc.

        Returns:
            Tuple of (coverage, monotonicity, confidence), each a Python float.
        """
        # Ensure high precision for internal calculation
        energy_matrix = energy_matrix.to(dtype=torch.float64)
        path_coords = path_coords.long()
        type_mask = type_mask.long()

        device = energy_matrix.device
        rows, cols = energy_matrix.shape

        is_lyrics_row = (type_mask == 1)

        # ================= A. Coverage Score =================
        # Ratio of lyric lines that have a significant energy peak.
        row_max_energies = energy_matrix.max(dim=1).values
        total_sung_rows = is_lyrics_row.sum().double()

        coverage_threshold = 0.1
        valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
        valid_sung_rows = valid_sung_mask.sum().double()

        if total_sung_rows > 0:
            coverage_score = valid_sung_rows / total_sung_rows
        else:
            # No lyric rows at all -> vacuously perfect coverage.
            coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= B. Monotonicity Score =================
        # Check if the "center of mass" of lyric lines moves forward in time.
        col_indices = torch.arange(cols, device=device, dtype=torch.float64)

        # Zero out low-energy noise below `time_weight`.
        weights = torch.where(
            energy_matrix > time_weight,
            energy_matrix,
            torch.zeros_like(energy_matrix)
        )

        sum_w = weights.sum(dim=1)
        sum_t = (weights * col_indices).sum(dim=1)

        # Calculate per-row centroids; -1 marks rows without enough energy.
        centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
        valid_w_mask = sum_w > 1e-9
        centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]

        # Extract sequence of valid lyrics centroids.
        valid_sequence_mask = is_lyrics_row & (centroids >= 0)
        sung_centroids = centroids[valid_sequence_mask]

        cnt = sung_centroids.shape[0]
        if cnt > 1:
            curr_c = sung_centroids[:-1]
            next_c = sung_centroids[1:]

            # Check non-decreasing order with `overlap_frames` tolerance.
            non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
            pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
            monotonicity_score = non_decreasing / pairs
        else:
            # Zero or one centroid is trivially monotonic.
            monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= C. Path Confidence =================
        # Weighted average energy along the optimal path.
        if path_coords.shape[0] > 0:
            p_rows = path_coords[:, 0]
            p_cols = path_coords[:, 1]

            path_energies = energy_matrix[p_rows, p_cols]
            step_weights = torch.ones_like(path_energies)

            # Lower weight for instrumental/tag steps.
            is_inst_step = (type_mask[p_rows] == 0)
            step_weights[is_inst_step] = instrumental_weight

            total_energy = (path_energies * step_weights).sum()
            total_steps = step_weights.sum()

            if total_steps > 0:
                path_confidence = total_energy / total_steps
            else:
                path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        else:
            path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)

        return coverage_score.item(), monotonicity_score.item(), path_confidence.item()

    def lyrics_alignment_info(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        token_ids: List[int],
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Generates alignment path and processed matrices.

        Args:
            attention_matrix: Input attention tensor.
            token_ids: Corresponding token IDs.
            custom_config: Layer/Head configuration.
            return_matrices: If True, also includes calc/vis matrices in the output.
            medfilt_width: Median filter width.

        Returns:
            Dict with "path_coords", "type_mask" and "energy_matrix" (plus
            "calc_matrix"/"vis_matrix" when requested), or a dict with an
            "error" key if no valid attention heads were found.
        """
        calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
            attention_matrix, custom_config, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "error": "No valid attention heads found"
            }

        # 1. Generate Semantic Mask (1=Lyrics, 0=Tags); uses self.tokenizer.
        type_mask = self._generate_token_type_mask(token_ids)

        # Safety check for shape mismatch
        if len(type_mask) != energy_matrix.shape[0]:
            # Fallback to all lyrics if shapes don't align
            type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)

        # 2. DTW Pathfinding
        # Using negative calc_matrix because DTW minimizes cost.
        # NOTE(review): `dtw_cpu` is defined elsewhere in this module.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
        path_coords = np.stack([text_indices, time_indices], axis=1)

        return_dict = {
            "path_coords": path_coords,
            "type_mask": type_mask,
            "energy_matrix": energy_matrix
        }
        if return_matrices:
            return_dict['calc_matrix'] = calc_matrix
            return_dict['vis_matrix'] = vis_matrix

        return return_dict

    def calculate_score(
        self,
        energy_matrix: Union[torch.Tensor, np.ndarray],
        type_mask: Union[torch.Tensor, np.ndarray],
        path_coords: Union[torch.Tensor, np.ndarray],
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Dict[str, Any]:
        """
        Calculates the final alignment score based on pre-computed components.

        Args:
            energy_matrix: Processed energy matrix.
            type_mask: Token type mask.
            path_coords: DTW path coordinates.
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward movement frames.
            instrumental_weight: Weight for non-lyric path steps.

        Returns:
            Dict with "lyrics_score": final score in [0, 1], rounded to 4 dp.
        """
        # Ensure inputs are tensors on a usable device.
        # BUGFIX: the device was hard-coded to 'cuda', which crashed on
        # CPU-only hosts; fall back to CPU when CUDA is unavailable.
        if not isinstance(energy_matrix, torch.Tensor):
            target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            energy_matrix = torch.tensor(
                energy_matrix, device=target_device, dtype=torch.float32
            )

        device = energy_matrix.device

        if not isinstance(type_mask, torch.Tensor):
            type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
        else:
            type_mask = type_mask.to(device=device, dtype=torch.long)

        if not isinstance(path_coords, torch.Tensor):
            path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
        else:
            path_coords = path_coords.to(device=device, dtype=torch.long)

        # Compute Metrics
        coverage, monotonicity, confidence = self._compute_alignment_metrics(
            energy_matrix=energy_matrix,
            path_coords=path_coords,
            type_mask=type_mask,
            time_weight=time_weight,
            overlap_frames=overlap_frames,
            instrumental_weight=instrumental_weight
        )

        # Final Score Calculation: (Cov^2 * Mono^2 * Conf), clipped to [0, 1].
        final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
        final_score = float(np.clip(final_score, 0.0, 1.0))

        return {
            "lyrics_score": round(final_score, 4)
        }
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -336,7 +336,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 336 |
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 337 |
def make_score_handler(idx):
|
| 338 |
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
|
| 339 |
-
llm_handler, idx, scale, batch_idx, queue
|
| 340 |
)
|
| 341 |
|
| 342 |
for btn_idx in range(1, 9):
|
|
|
|
| 336 |
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 337 |
def make_score_handler(idx):
|
| 338 |
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
|
| 339 |
+
dit_handler, llm_handler, idx, scale, batch_idx, queue
|
| 340 |
)
|
| 341 |
|
| 342 |
for btn_idx in range(1, 9):
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -714,7 +714,22 @@ def generate_with_progress(
|
|
| 714 |
|
| 715 |
|
| 716 |
|
| 717 |
-
def calculate_score_handler(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 718 |
"""
|
| 719 |
Calculate PMI-based quality score for generated audio.
|
| 720 |
|
|
@@ -733,6 +748,9 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
|
|
| 733 |
audio_duration: Audio duration value
|
| 734 |
vocal_language: Vocal language value
|
| 735 |
score_scale: Sensitivity scale parameter
|
|
|
|
|
|
|
|
|
|
| 736 |
|
| 737 |
Returns:
|
| 738 |
Score display string
|
|
@@ -791,7 +809,37 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
|
|
| 791 |
topk=10,
|
| 792 |
score_scale=score_scale
|
| 793 |
)
|
| 794 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
# Format display string with per-condition breakdown
|
| 796 |
if global_score == 0.0 and not scores_per_condition:
|
| 797 |
return t("messages.score_failed", error=status)
|
|
@@ -804,12 +852,17 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
|
|
| 804 |
)
|
| 805 |
|
| 806 |
conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
|
| 807 |
-
|
| 808 |
-
|
| 809 |
f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
|
| 810 |
-
f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n
|
| 811 |
-
f"Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI\n"
|
| 812 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
except Exception as e:
|
| 815 |
import traceback
|
|
@@ -817,12 +870,19 @@ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_me
|
|
| 817 |
return error_msg
|
| 818 |
|
| 819 |
|
| 820 |
-
def calculate_score_handler_with_selection(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
"""
|
| 822 |
Calculate PMI-based quality score - REFACTORED to read from batch_queue only.
|
| 823 |
This ensures scoring uses the actual generation parameters, not current UI values.
|
| 824 |
|
| 825 |
Args:
|
|
|
|
| 826 |
llm_handler: LLM handler instance
|
| 827 |
sample_idx: Which sample to score (1-8)
|
| 828 |
score_scale: Sensitivity scale parameter (tool setting, can be from UI)
|
|
@@ -843,6 +903,7 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
|
|
| 843 |
time_signature = params.get("time_signature", "")
|
| 844 |
audio_duration = params.get("audio_duration", -1)
|
| 845 |
vocal_language = params.get("vocal_language", "")
|
|
|
|
| 846 |
|
| 847 |
# Get LM metadata from batch_data (if it was saved during generation)
|
| 848 |
lm_metadata = batch_data.get("lm_generated_metadata", None)
|
|
@@ -862,13 +923,51 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
|
|
| 862 |
else:
|
| 863 |
# Single mode: all samples use same codes
|
| 864 |
audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
|
| 865 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 866 |
# Calculate score using historical parameters
|
| 867 |
score_display = calculate_score_handler(
|
| 868 |
llm_handler,
|
| 869 |
audio_codes_str, caption, lyrics, lm_metadata,
|
| 870 |
bpm, key_scale, time_signature, audio_duration, vocal_language,
|
| 871 |
-
score_scale
|
|
|
|
|
|
|
|
|
|
| 872 |
)
|
| 873 |
|
| 874 |
# Update batch_queue with the calculated score
|
|
|
|
| 714 |
|
| 715 |
|
| 716 |
|
| 717 |
+
def calculate_score_handler(
|
| 718 |
+
llm_handler,
|
| 719 |
+
audio_codes_str,
|
| 720 |
+
caption,
|
| 721 |
+
lyrics,
|
| 722 |
+
lm_metadata,
|
| 723 |
+
bpm,
|
| 724 |
+
key_scale,
|
| 725 |
+
time_signature,
|
| 726 |
+
audio_duration,
|
| 727 |
+
vocal_language,
|
| 728 |
+
score_scale,
|
| 729 |
+
dit_handler,
|
| 730 |
+
extra_tensor_data,
|
| 731 |
+
inference_steps,
|
| 732 |
+
):
|
| 733 |
"""
|
| 734 |
Calculate PMI-based quality score for generated audio.
|
| 735 |
|
|
|
|
| 748 |
audio_duration: Audio duration value
|
| 749 |
vocal_language: Vocal language value
|
| 750 |
score_scale: Sensitivity scale parameter
|
| 751 |
+
dit_handler: DiT handler instance (for alignment scoring)
|
| 752 |
+
extra_tensor_data: Dictionary containing tensors for the specific sample
|
| 753 |
+
inference_steps: Number of inference steps used
|
| 754 |
|
| 755 |
Returns:
|
| 756 |
Score display string
|
|
|
|
| 809 |
topk=10,
|
| 810 |
score_scale=score_scale
|
| 811 |
)
|
| 812 |
+
|
| 813 |
+
alignment_report = ""
|
| 814 |
+
|
| 815 |
+
# Only calculate if we have the handler, tensor data, and actual lyrics
|
| 816 |
+
if dit_handler and extra_tensor_data and lyrics and lyrics.strip():
|
| 817 |
+
try:
|
| 818 |
+
align_result = dit_handler.get_lyric_score(
|
| 819 |
+
pred_latent=extra_tensor_data.get('pred_latent'),
|
| 820 |
+
encoder_hidden_states=extra_tensor_data.get('encoder_hidden_states'),
|
| 821 |
+
encoder_attention_mask=extra_tensor_data.get('encoder_attention_mask'),
|
| 822 |
+
context_latents=extra_tensor_data.get('context_latents'),
|
| 823 |
+
lyric_token_ids=extra_tensor_data.get('lyric_token_ids'),
|
| 824 |
+
vocal_language=vocal_language or "en",
|
| 825 |
+
inference_steps=int(inference_steps),
|
| 826 |
+
seed=42,
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
if align_result.get("success"):
|
| 830 |
+
lm_align_score = align_result.get("lm_score", 0.0)
|
| 831 |
+
dit_align_score = align_result.get("dit_score", 0.0)
|
| 832 |
+
alignment_report = (
|
| 833 |
+
f" • llm lyrics alignment score: {lm_align_score:.4f}\n"
|
| 834 |
+
f" • dit lyrics alignment score: {dit_align_score:.4f}\n"
|
| 835 |
+
"\n(Measures how well lyrics timestamps match audio energy using Cross-Attention)"
|
| 836 |
+
)
|
| 837 |
+
else:
|
| 838 |
+
align_err = align_result.get("error", "Unknown error")
|
| 839 |
+
alignment_report = f"\n⚠️ Alignment Score Failed: {align_err}"
|
| 840 |
+
except Exception as e:
|
| 841 |
+
alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
|
| 842 |
+
|
| 843 |
# Format display string with per-condition breakdown
|
| 844 |
if global_score == 0.0 and not scores_per_condition:
|
| 845 |
return t("messages.score_failed", error=status)
|
|
|
|
| 852 |
)
|
| 853 |
|
| 854 |
conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
|
| 855 |
+
|
| 856 |
+
final_output = (
|
| 857 |
f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
|
| 858 |
+
f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n"
|
|
|
|
| 859 |
)
|
| 860 |
+
|
| 861 |
+
if alignment_report:
|
| 862 |
+
final_output += alignment_report + "\n"
|
| 863 |
+
|
| 864 |
+
final_output += "Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI"
|
| 865 |
+
return final_output
|
| 866 |
|
| 867 |
except Exception as e:
|
| 868 |
import traceback
|
|
|
|
| 870 |
return error_msg
|
| 871 |
|
| 872 |
|
| 873 |
+
def calculate_score_handler_with_selection(
|
| 874 |
+
dit_handler,
|
| 875 |
+
llm_handler,
|
| 876 |
+
sample_idx,
|
| 877 |
+
score_scale,
|
| 878 |
+
current_batch_index,
|
| 879 |
+
batch_queue):
|
| 880 |
"""
|
| 881 |
Calculate PMI-based quality score - REFACTORED to read from batch_queue only.
|
| 882 |
This ensures scoring uses the actual generation parameters, not current UI values.
|
| 883 |
|
| 884 |
Args:
|
| 885 |
+
dit_handler: DiT Handler
|
| 886 |
llm_handler: LLM handler instance
|
| 887 |
sample_idx: Which sample to score (1-8)
|
| 888 |
score_scale: Sensitivity scale parameter (tool setting, can be from UI)
|
|
|
|
| 903 |
time_signature = params.get("time_signature", "")
|
| 904 |
audio_duration = params.get("audio_duration", -1)
|
| 905 |
vocal_language = params.get("vocal_language", "")
|
| 906 |
+
inference_steps = params.get("inference_steps", 8)
|
| 907 |
|
| 908 |
# Get LM metadata from batch_data (if it was saved during generation)
|
| 909 |
lm_metadata = batch_data.get("lm_generated_metadata", None)
|
|
|
|
| 923 |
else:
|
| 924 |
# Single mode: all samples use same codes
|
| 925 |
audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
|
| 926 |
+
|
| 927 |
+
# Extract Tensor Data for Alignment Score (Extra Outputs)
|
| 928 |
+
extra_tensor_data = None
|
| 929 |
+
extra_outputs = batch_data.get("extra_outputs", {})
|
| 930 |
+
|
| 931 |
+
# Only proceed if we have tensors and a valid index
|
| 932 |
+
if extra_outputs and dit_handler:
|
| 933 |
+
pred_latents = extra_outputs.get("pred_latents")
|
| 934 |
+
# Ensure we have the critical tensor to check batch size
|
| 935 |
+
if pred_latents is not None:
|
| 936 |
+
sample_idx_0based = sample_idx - 1
|
| 937 |
+
batch_size = pred_latents.shape[0]
|
| 938 |
+
|
| 939 |
+
if 0 <= sample_idx_0based < batch_size:
|
| 940 |
+
# Slice tensors for this specific sample (keep dimension [1, ...])
|
| 941 |
+
# We assume all stored tensors are aligned in batch dim 0
|
| 942 |
+
try:
|
| 943 |
+
extra_tensor_data = {
|
| 944 |
+
"pred_latent": pred_latents[sample_idx_0based:sample_idx_0based + 1],
|
| 945 |
+
"encoder_hidden_states": extra_outputs.get("encoder_hidden_states")[
|
| 946 |
+
sample_idx_0based:sample_idx_0based + 1],
|
| 947 |
+
"encoder_attention_mask": extra_outputs.get("encoder_attention_mask")[
|
| 948 |
+
sample_idx_0based:sample_idx_0based + 1],
|
| 949 |
+
"context_latents": extra_outputs.get("context_latents")[
|
| 950 |
+
sample_idx_0based:sample_idx_0based + 1],
|
| 951 |
+
"lyric_token_ids": extra_outputs.get("lyric_token_idss")[
|
| 952 |
+
sample_idx_0based:sample_idx_0based + 1]
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
# Verify no None values in the sliced dict
|
| 956 |
+
if any(v is None for v in extra_tensor_data.values()):
|
| 957 |
+
extra_tensor_data = None
|
| 958 |
+
except Exception as e:
|
| 959 |
+
print(f"Error slicing tensor data for score: {e}")
|
| 960 |
+
extra_tensor_data = None
|
| 961 |
+
|
| 962 |
# Calculate score using historical parameters
|
| 963 |
score_display = calculate_score_handler(
|
| 964 |
llm_handler,
|
| 965 |
audio_codes_str, caption, lyrics, lm_metadata,
|
| 966 |
bpm, key_scale, time_signature, audio_duration, vocal_language,
|
| 967 |
+
score_scale,
|
| 968 |
+
dit_handler,
|
| 969 |
+
extra_tensor_data,
|
| 970 |
+
inference_steps,
|
| 971 |
)
|
| 972 |
|
| 973 |
# Update batch_queue with the calculated score
|
acestep/handler.py
CHANGED
|
@@ -31,7 +31,7 @@ from acestep.constants import (
|
|
| 31 |
SFT_GEN_PROMPT,
|
| 32 |
DEFAULT_DIT_INSTRUCTION,
|
| 33 |
)
|
| 34 |
-
from acestep.dit_alignment_score import MusicStampsAligner
|
| 35 |
|
| 36 |
|
| 37 |
warnings.filterwarnings("ignore")
|
|
@@ -2553,3 +2553,229 @@ class AceStepHandler:
|
|
| 2553 |
"success": False,
|
| 2554 |
"error": error_msg
|
| 2555 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
SFT_GEN_PROMPT,
|
| 32 |
DEFAULT_DIT_INSTRUCTION,
|
| 33 |
)
|
| 34 |
+
from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer
|
| 35 |
|
| 36 |
|
| 37 |
warnings.filterwarnings("ignore")
|
|
|
|
| 2553 |
"success": False,
|
| 2554 |
"error": error_msg
|
| 2555 |
}
|
| 2556 |
+
|
| 2557 |
+
@torch.no_grad()
def get_lyric_score(
    self,
    pred_latent: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_attention_mask: torch.Tensor,
    context_latents: torch.Tensor,
    lyric_token_ids: torch.Tensor,
    vocal_language: str = "en",
    inference_steps: int = 8,
    seed: int = 42,
    custom_layers_config: Optional[Dict] = None,
) -> Dict[str, Any]:
    """
    Calculate both LM and DiT lyric-alignment scores in one decoder pass.

    - lm_score: structural alignment, probed with pure noise at t = 1.0.
    - dit_score: denoising alignment, probed with regressed latents at
      t = 1 / inference_steps.

    Both probes are concatenated into a single doubled batch
    (order: [LM batch, DiT batch]) so the decoder runs only once.

    Args:
        pred_latent: Generated latent tensor [batch, T, D].
        encoder_hidden_states: Cached encoder hidden states.
        encoder_attention_mask: Cached encoder attention mask.
        context_latents: Cached context latents.
        lyric_token_ids: Tokenized lyrics, tensor [batch, seq_len] or a
            plain list of token ids (only batch element 0 is scored).
        vocal_language: Language code used to rebuild the lyrics header.
        inference_steps: Number of inference steps; sets the DiT noise
            level t = 1 / inference_steps.
        seed: Random seed for the noise sample; None draws unseeded noise.
        custom_layers_config: Dict mapping layer indices to head indices;
            falls back to ``self.custom_layers_config`` when None.

    Returns:
        Dict with keys:
            - lm_score: float
            - dit_score: float
            - success: bool
            - error: str or None
    """
    # Token id of <|endoftext|>; marks the end of the lyric span inside
    # lyric_token_ids (Qwen-style vocabulary — confirm against tokenizer).
    EOT_TOKEN_ID = 151643

    def _failure(message: str) -> Dict[str, Any]:
        # Uniform failure payload so every early exit looks identical.
        return {
            "lm_score": 0.0,
            "dit_score": 0.0,
            "success": False,
            "error": message,
        }

    if self.model is None:
        return _failure("Model not initialized")

    if custom_layers_config is None:
        custom_layers_config = self.custom_layers_config

    try:
        # Move all cached tensors onto the model's device/dtype.
        device = self.device
        dtype = self.dtype

        pred_latent = pred_latent.to(device=device, dtype=dtype)
        encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
        encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
        context_latents = context_latents.to(device=device, dtype=dtype)

        bsz = pred_latent.shape[0]

        # One noise sample shared by both probes (reproducible when seeded).
        if seed is None:
            x0 = torch.randn_like(pred_latent)
        else:
            generator = torch.Generator(device=device).manual_seed(int(seed))
            x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype)

        # --- Input A: LM score --- t = 1.0, xt = pure noise.
        t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype)
        xt_lm = x0

        # --- Input B: DiT score --- t = 1/steps, latent regressed toward
        # the prediction via flow matching: xt = t*x0 + (1-t)*x1.
        t_last_val = 1.0 / inference_steps
        t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
        xt_dit = t_last_val * x0 + (1.0 - t_last_val) * pred_latent

        # Doubled batch, order: [LM batch, DiT batch].
        xt_in = torch.cat([xt_lm, xt_dit], dim=0)
        t_in = torch.cat([t_lm, t_dit], dim=0)

        # Duplicate the conditioning to match the doubled batch.
        encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0)
        encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0)
        context_latents_in = torch.cat([context_latents, context_latents], dim=0)

        # Full attention over the latent sequence for both halves.
        latent_length = xt_in.shape[1]
        attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype)
        past_key_values = None

        # Single decoder forward with cross-attentions captured.
        with self._load_model_context("model"):
            decoder = self.model.decoder
            if hasattr(decoder, 'eval'):
                decoder.eval()

            decoder_outputs = decoder(
                hidden_states=xt_in,
                timestep=t_in,
                timestep_r=t_in,
                attention_mask=attention_mask_in,
                encoder_hidden_states=encoder_hidden_states_in,
                use_cache=False,
                past_key_values=past_key_values,
                encoder_attention_mask=encoder_attention_mask_in,
                context_latents=context_latents_in,
                output_attentions=True,
                custom_layers_config=custom_layers_config,
                enable_early_exit=True
            )

        # decoder_outputs[2] holds per-layer cross-attention tensors;
        # individual entries may be None for layers that were not captured.
        if decoder_outputs[2] is None:
            return _failure("Model did not return attentions")

        cross_attns = decoder_outputs[2]

        # Swap the last two axes of each captured layer so the slicing
        # below (dim -2) selects the text/token dimension.
        captured_layers_list = [
            layer_attn.transpose(-1, -2)
            for layer_attn in cross_attns
            if layer_attn is not None
        ]

        if not captured_layers_list:
            return _failure("No valid attention layers returned")

        stacked = torch.stack(captured_layers_list)

        # Split the doubled batch back into the LM and DiT halves.
        all_layers_matrix_lm = stacked[:, :bsz, ...]
        all_layers_matrix_dit = stacked[:, bsz:, ...]

        if bsz == 1:
            # Drop the singleton batch dim for the aligner.
            all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1)
            all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1)

        # Extract the pure lyric token ids (strip header and EOT tail);
        # only the first batch element is scored.
        if isinstance(lyric_token_ids, torch.Tensor):
            raw_lyric_ids = lyric_token_ids[0].tolist()
        else:
            raw_lyric_ids = lyric_token_ids

        # The lyric prompt is "# Languages\n<lang>\n\n# Lyric\n<lyrics>";
        # tokenizing the header alone locates where the lyrics start.
        header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
        header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
        start_idx = len(header_ids)

        try:
            end_idx = raw_lyric_ids.index(EOT_TOKEN_ID)
        except ValueError:
            # No <|endoftext|> present: lyrics run to the end of the ids.
            end_idx = len(raw_lyric_ids)

        pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
        if start_idx >= all_layers_matrix_lm.shape[-2]:  # text dimension
            return _failure("Lyrics indices out of bounds")

        pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :]
        pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :]

        # Score both probes with the same aligner instance.
        aligner = MusicLyricScorer(self.text_tokenizer)

        def calculate_single_score(matrix):
            """Run the aligner on one attention matrix; 0.0 if unusable."""
            info = aligner.lyrics_alignment_info(
                attention_matrix=matrix,
                token_ids=pure_lyric_ids,
                custom_config=custom_layers_config,
                return_matrices=False,
                medfilt_width=1,
            )
            if info.get("energy_matrix") is None:
                return 0.0

            res = aligner.calculate_score(
                energy_matrix=info["energy_matrix"],
                type_mask=info["type_mask"],
                path_coords=info["path_coords"],
            )
            # Prefer "lyrics_score"; fall back to "final_score" for
            # older aligner result layouts.
            return res.get("lyrics_score", res.get("final_score", 0.0))

        lm_score = calculate_single_score(pure_matrix_lm)
        dit_score = calculate_single_score(pure_matrix_dit)

        return {
            "lm_score": lm_score,
            "dit_score": dit_score,
            "success": True,
            "error": None
        }

    except Exception as e:
        logger.exception("[get_lyric_score] Failed")
        return _failure(f"Error generating score: {str(e)}")
|