|
|
""" |
|
|
DiT Alignment Score Module |
|
|
|
|
|
This module provides lyrics-to-audio alignment using cross-attention matrices |
|
|
from the DiT model to generate LRC timestamps.
|
|
|
|
|
Refactored from lyrics_alignment_infos.py for integration with ACE-Step. |
|
|
""" |
|
|
import numba |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
from dataclasses import dataclass
|
|
from typing import List, Dict, Any, Optional, Tuple, Union |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class TokenTimestamp: |
|
|
"""Stores per-token timing information.""" |
|
|
token_id: int |
|
|
text: str |
|
|
start: float |
|
|
end: float |
|
|
probability: float |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class SentenceTimestamp: |
|
|
"""Stores per-sentence timing information with token list.""" |
|
|
text: str |
|
|
start: float |
|
|
end: float |
|
|
tokens: List[TokenTimestamp] |
|
|
confidence: float |
|
|
|
|
|
|
|
|
|
|
|
@numba.jit(nopython=True) |
|
|
def dtw_cpu(x: np.ndarray): |
|
|
""" |
|
|
Dynamic Time Warping algorithm optimized with Numba. |
|
|
|
|
|
Args: |
|
|
x: Cost matrix of shape [N, M] |
|
|
|
|
|
Returns: |
|
|
        Array of shape (2, path_len); row 0 holds text indices, row 1 holds
        time indices, so it unpacks as a (text_indices, time_indices) pair
|
|
""" |
|
|
N, M = x.shape |
|
|
|
|
|
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf |
|
|
trace = -np.ones((N + 1, M + 1), dtype=np.float32) |
|
|
cost[0, 0] = 0 |
|
|
|
|
|
for j in range(1, M + 1): |
|
|
for i in range(1, N + 1): |
|
|
c0 = cost[i - 1, j - 1] |
|
|
c1 = cost[i - 1, j] |
|
|
c2 = cost[i, j - 1] |
|
|
|
|
|
if c0 < c1 and c0 < c2: |
|
|
c, t = c0, 0 |
|
|
elif c1 < c0 and c1 < c2: |
|
|
c, t = c1, 1 |
|
|
else: |
|
|
c, t = c2, 2 |
|
|
|
|
|
cost[i, j] = x[i - 1, j - 1] + c |
|
|
trace[i, j] = t |
|
|
|
|
|
return _backtrace(trace, N, M) |
|
|
|
|
|
|
|
|
@numba.jit(nopython=True) |
|
|
def _backtrace(trace: np.ndarray, N: int, M: int): |
|
|
""" |
|
|
Optimized backtrace function for DTW. |
|
|
|
|
|
Args: |
|
|
trace: Trace matrix of shape (N+1, M+1) |
|
|
N, M: Original matrix dimensions |
|
|
|
|
|
Returns: |
|
|
Path array of shape (2, path_len) - first row is text indices, second is time indices |
|
|
""" |
|
|
|
|
|
trace[0, :] = 2 |
|
|
trace[:, 0] = 1 |
|
|
|
|
|
|
|
|
max_path_len = N + M |
|
|
path = np.zeros((2, max_path_len), dtype=np.int32) |
|
|
|
|
|
i, j = N, M |
|
|
path_idx = max_path_len - 1 |
|
|
|
|
|
while i > 0 or j > 0: |
|
|
path[0, path_idx] = i - 1 |
|
|
path[1, path_idx] = j - 1 |
|
|
path_idx -= 1 |
|
|
|
|
|
t = trace[i, j] |
|
|
if t == 0: |
|
|
i -= 1 |
|
|
j -= 1 |
|
|
elif t == 1: |
|
|
i -= 1 |
|
|
elif t == 2: |
|
|
j -= 1 |
|
|
else: |
|
|
break |
|
|
|
|
|
    return path[:, path_idx + 1:]
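
# Illustrative sketch (not part of the pipeline): run dtw_cpu on a tiny
# similarity matrix. Callers negate similarities so that high attention
# becomes low cost; the returned path is monotonic in both indices.
def _demo_dtw_cpu():
    similarity = np.array(
        [
            [0.9, 0.8, 0.1, 0.0, 0.0],
            [0.0, 0.2, 0.9, 0.3, 0.1],
            [0.0, 0.0, 0.1, 0.8, 0.9],
        ],
        dtype=np.float64,
    )
    text_indices, time_indices = dtw_cpu(-similarity)
    # text_indices climbs 0 -> 2 and time_indices 0 -> 4, never decreasing.
    return text_indices, time_indices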
|
|
|
|
|
|
|
|
|
|
|
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor: |
|
|
""" |
|
|
Apply median filter to tensor. |
|
|
|
|
|
Args: |
|
|
x: Input tensor |
|
|
filter_width: Width of median filter |
|
|
|
|
|
Returns: |
|
|
Filtered tensor |
|
|
""" |
|
|
    assert filter_width > 0 and filter_width % 2 == 1, "filter_width must be a positive odd number"
    pad_width = filter_width // 2
|
|
if x.shape[-1] <= pad_width: |
|
|
return x |
|
|
    unsqueezed = False
    if x.ndim == 2:
        x = x[None, :]
        unsqueezed = True
    x = F.pad(x, (pad_width, pad_width, 0, 0), mode="reflect")
    result = x.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
    # Only undo the batch dim we added here; an unconditional `squeeze(0)`
    # would silently drop the head axis of a genuine 3D single-head stack.
    if unsqueezed:
        result = result.squeeze(0)
    return result
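
# Illustrative sketch: the median filter removes isolated spikes along the
# frame axis while keeping the broad attention ridge intact. Values here
# are made up.
def _demo_median_filter():
    row = torch.tensor([[0.0, 0.1, 5.0, 0.1, 0.2, 0.1, 0.0, 0.1]])
    # The lone 5.0 spike becomes the median of its 3-wide neighborhood (0.1).
    return median_filter(row, filter_width=3)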
|
|
|
|
|
|
|
|
|
|
|
class MusicStampsAligner: |
|
|
""" |
|
|
Aligner class for generating lyrics timestamps from cross-attention matrices. |
|
|
|
|
|
Uses bidirectional consensus denoising and DTW for alignment. |
|
|
""" |
|
|
|
|
|
def __init__(self, tokenizer): |
|
|
""" |
|
|
Initialize the aligner. |
|
|
|
|
|
Args: |
|
|
tokenizer: Text tokenizer for decoding tokens |
|
|
""" |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
def _apply_bidirectional_consensus( |
|
|
self, |
|
|
weights_stack: torch.Tensor, |
|
|
violence_level: float, |
|
|
medfilt_width: int |
|
|
) -> tuple: |
|
|
""" |
|
|
Core denoising logic using bidirectional consensus. |
|
|
|
|
|
Args: |
|
|
weights_stack: Attention weights [Heads, Tokens, Frames] |
|
|
violence_level: Denoising strength coefficient |
|
|
medfilt_width: Median filter width |
|
|
|
|
|
Returns: |
|
|
Tuple of (calc_matrix, energy_matrix) as numpy arrays |
|
|
""" |
|
|
|
|
|
row_prob = F.softmax(weights_stack, dim=-1) |
|
|
col_prob = F.softmax(weights_stack, dim=-2) |
|
|
processed = row_prob * col_prob |
|
|
|
|
|
|
|
|
row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True) |
|
|
processed = processed - (violence_level * row_medians) |
|
|
processed = torch.relu(processed) |
|
|
|
|
|
|
|
|
col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True) |
|
|
processed = processed - (violence_level * col_medians) |
|
|
processed = torch.relu(processed) |
|
|
|
|
|
|
|
|
processed = processed ** 2 |
|
|
|
|
|
|
|
|
energy_matrix = processed.mean(dim=0).cpu().numpy() |
|
|
|
|
|
|
|
|
std, mean = torch.std_mean(processed, unbiased=False) |
|
|
weights_processed = (processed - mean) / (std + 1e-9) |
|
|
|
|
|
|
|
|
weights_processed = median_filter(weights_processed, filter_width=medfilt_width) |
|
|
calc_matrix = weights_processed.mean(dim=0).numpy() |
|
|
|
|
|
return calc_matrix, energy_matrix |
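
    # Illustrative sketch (hypothetical helper, not used by the pipeline):
    # shows the shapes flowing through the consensus step above.
    def _demo_consensus_shapes(self):
        weights_stack = torch.randn(2, 4, 10)  # 2 heads, 4 tokens, 10 frames
        calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
            weights_stack, violence_level=2.0, medfilt_width=1
        )
        # Both outputs collapse the head axis into (Tokens, Frames) arrays.
        assert calc_matrix.shape == (4, 10)
        assert energy_matrix.shape == (4, 10)
        return calc_matrix, energy_matrix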
|
|
|
|
|
def _preprocess_attention( |
|
|
self, |
|
|
attention_matrix: torch.Tensor, |
|
|
custom_config: Dict[int, List[int]], |
|
|
violence_level: float, |
|
|
medfilt_width: int = 7 |
|
|
) -> tuple: |
|
|
""" |
|
|
Preprocess attention matrix for alignment. |
|
|
|
|
|
Args: |
|
|
attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames] |
|
|
custom_config: Dict mapping layer indices to head indices |
|
|
violence_level: Denoising strength |
|
|
medfilt_width: Median filter width |
|
|
|
|
|
Returns: |
|
|
Tuple of (calc_matrix, energy_matrix, visual_matrix) |
|
|
""" |
|
|
if not isinstance(attention_matrix, torch.Tensor): |
|
|
weights = torch.tensor(attention_matrix) |
|
|
else: |
|
|
weights = attention_matrix.clone() |
|
|
|
|
|
weights = weights.cpu().float() |
|
|
|
|
|
selected_tensors = [] |
|
|
for layer_idx, head_indices in custom_config.items(): |
|
|
for head_idx in head_indices: |
|
|
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]: |
|
|
head_matrix = weights[layer_idx, head_idx] |
|
|
selected_tensors.append(head_matrix) |
|
|
|
|
|
if not selected_tensors: |
|
|
return None, None, None |
|
|
|
|
|
|
|
|
weights_stack = torch.stack(selected_tensors, dim=0) |
|
|
visual_matrix = weights_stack.mean(dim=0).numpy() |
|
|
|
|
|
calc_matrix, energy_matrix = self._apply_bidirectional_consensus( |
|
|
weights_stack, violence_level, medfilt_width |
|
|
) |
|
|
|
|
|
return calc_matrix, energy_matrix, visual_matrix |
|
|
|
|
|
def stamps_align_info( |
|
|
self, |
|
|
attention_matrix: torch.Tensor, |
|
|
lyrics_tokens: List[int], |
|
|
total_duration_seconds: float, |
|
|
custom_config: Dict[int, List[int]], |
|
|
return_matrices: bool = False, |
|
|
violence_level: float = 2.0, |
|
|
medfilt_width: int = 1 |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Get alignment information from attention matrix. |
|
|
|
|
|
Args: |
|
|
attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames] |
|
|
lyrics_tokens: List of lyrics token IDs |
|
|
total_duration_seconds: Total audio duration in seconds |
|
|
custom_config: Dict mapping layer indices to head indices |
|
|
return_matrices: Whether to return intermediate matrices |
|
|
violence_level: Denoising strength |
|
|
medfilt_width: Median filter width |
|
|
|
|
|
Returns: |
|
|
Dict containing calc_matrix, lyrics_tokens, total_duration_seconds, |
|
|
and optionally energy_matrix and vis_matrix |
|
|
""" |
|
|
calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention( |
|
|
attention_matrix, custom_config, violence_level, medfilt_width |
|
|
) |
|
|
|
|
|
if calc_matrix is None: |
|
|
return { |
|
|
"calc_matrix": None, |
|
|
"lyrics_tokens": lyrics_tokens, |
|
|
"total_duration_seconds": total_duration_seconds, |
|
|
"error": "No valid attention heads found" |
|
|
} |
|
|
|
|
|
return_dict = { |
|
|
"calc_matrix": calc_matrix, |
|
|
"lyrics_tokens": lyrics_tokens, |
|
|
"total_duration_seconds": total_duration_seconds |
|
|
} |
|
|
|
|
|
if return_matrices: |
|
|
return_dict['energy_matrix'] = energy_matrix |
|
|
return_dict['vis_matrix'] = visual_matrix |
|
|
|
|
|
return return_dict |
|
|
|
|
|
def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]: |
|
|
""" |
|
|
Decode tokens incrementally to properly handle multi-byte UTF-8 characters. |
|
|
|
|
|
For Chinese and other multi-byte characters, the tokenizer may split them |
|
|
into multiple byte-level tokens. Decoding each token individually produces |
|
|
invalid UTF-8 sequences (showing as �). This method uses byte-level comparison |
|
|
to correctly track which characters each token contributes. |
|
|
|
|
|
Args: |
|
|
token_ids: List of token IDs |
|
|
|
|
|
Returns: |
|
|
List of decoded text for each token position |
|
|
""" |
|
|
decoded_tokens = [] |
|
|
prev_bytes = b"" |
|
|
|
|
|
for i in range(len(token_ids)): |
|
|
|
|
|
current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False) |
|
|
current_bytes = current_text.encode('utf-8', errors='surrogatepass') |
|
|
|
|
|
|
|
|
if len(current_bytes) >= len(prev_bytes): |
|
|
new_bytes = current_bytes[len(prev_bytes):] |
|
|
|
|
|
try: |
|
|
token_text = new_bytes.decode('utf-8') |
|
|
except UnicodeDecodeError: |
|
|
|
|
|
token_text = "" |
|
|
else: |
|
|
|
|
|
token_text = "" |
|
|
|
|
|
decoded_tokens.append(token_text) |
|
|
prev_bytes = current_bytes |
|
|
|
|
|
return decoded_tokens |
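
    # Illustrative sketch with a toy byte-level tokenizer (hypothetical): the
    # middle token completes no character on its own, so it contributes "",
    # and the token that finishes the multi-byte "é" receives it whole.
    def _demo_incremental_decode(self):
        class _ToyTokenizer:
            _texts = {(1,): "caf", (1, 2): "caf", (1, 2, 3): "café"}

            def decode(self, ids, skip_special_tokens=False):
                return self._texts[tuple(ids)]

        toy = MusicStampsAligner(_ToyTokenizer())
        return toy._decode_tokens_incrementally([1, 2, 3])  # ["caf", "", "é"]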
|
|
|
|
|
def token_timestamps( |
|
|
self, |
|
|
calc_matrix: np.ndarray, |
|
|
lyrics_tokens: List[int], |
|
|
total_duration_seconds: float |
|
|
) -> List[TokenTimestamp]: |
|
|
""" |
|
|
Generate per-token timestamps using DTW. |
|
|
|
|
|
Args: |
|
|
calc_matrix: Processed attention matrix [Tokens, Frames] |
|
|
lyrics_tokens: List of token IDs |
|
|
total_duration_seconds: Total audio duration |
|
|
|
|
|
Returns: |
|
|
List of TokenTimestamp objects |
|
|
""" |
|
|
n_frames = calc_matrix.shape[-1] |
|
|
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64)) |
|
|
|
|
|
seconds_per_frame = total_duration_seconds / n_frames |
|
|
alignment_results = [] |
|
|
|
|
|
|
|
|
decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens) |
|
|
|
|
|
for i in range(len(lyrics_tokens)): |
|
|
mask = (text_indices == i) |
|
|
|
|
|
if not np.any(mask): |
|
|
start = alignment_results[-1].end if alignment_results else 0.0 |
|
|
end = start |
|
|
token_conf = 0.0 |
|
|
else: |
|
|
times = time_indices[mask] * seconds_per_frame |
|
|
start = times[0] |
|
|
end = times[-1] |
|
|
            token_conf = 0.0  # placeholder: confidence is not derived from attention yet
|
|
|
|
|
if end < start: |
|
|
end = start |
|
|
|
|
|
alignment_results.append(TokenTimestamp( |
|
|
token_id=lyrics_tokens[i], |
|
|
text=decoded_tokens[i], |
|
|
start=float(start), |
|
|
end=float(end), |
|
|
probability=token_conf |
|
|
)) |
|
|
|
|
|
return alignment_results |
|
|
|
|
|
def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str: |
|
|
""" |
|
|
Decode a sentence by decoding all token IDs together. |
|
|
This avoids UTF-8 encoding issues from joining individual token texts. |
|
|
|
|
|
Args: |
|
|
tokens: List of TokenTimestamp objects |
|
|
|
|
|
Returns: |
|
|
Properly decoded sentence text |
|
|
""" |
|
|
token_ids = [t.token_id for t in tokens] |
|
|
return self.tokenizer.decode(token_ids, skip_special_tokens=False) |
|
|
|
|
|
def sentence_timestamps( |
|
|
self, |
|
|
token_alignment: List[TokenTimestamp] |
|
|
) -> List[SentenceTimestamp]: |
|
|
""" |
|
|
Group token timestamps into sentence timestamps. |
|
|
|
|
|
Args: |
|
|
token_alignment: List of TokenTimestamp objects |
|
|
|
|
|
Returns: |
|
|
List of SentenceTimestamp objects |
|
|
""" |
|
|
results = [] |
|
|
current_tokens = [] |
|
|
|
|
|
for token in token_alignment: |
|
|
current_tokens.append(token) |
|
|
|
|
|
if '\n' in token.text: |
|
|
|
|
|
full_text = self._decode_sentence_from_tokens(current_tokens) |
|
|
|
|
|
if full_text.strip(): |
|
|
valid_scores = [t.probability for t in current_tokens if t.probability > 0] |
|
|
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0 |
|
|
|
|
|
results.append(SentenceTimestamp( |
|
|
text=full_text.strip(), |
|
|
start=round(current_tokens[0].start, 3), |
|
|
end=round(current_tokens[-1].end, 3), |
|
|
tokens=list(current_tokens), |
|
|
confidence=sent_conf |
|
|
)) |
|
|
|
|
|
current_tokens = [] |
|
|
|
|
|
|
|
|
if current_tokens: |
|
|
|
|
|
full_text = self._decode_sentence_from_tokens(current_tokens) |
|
|
if full_text.strip(): |
|
|
valid_scores = [t.probability for t in current_tokens if t.probability > 0] |
|
|
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0 |
|
|
|
|
|
results.append(SentenceTimestamp( |
|
|
text=full_text.strip(), |
|
|
start=round(current_tokens[0].start, 3), |
|
|
end=round(current_tokens[-1].end, 3), |
|
|
tokens=list(current_tokens), |
|
|
confidence=sent_conf |
|
|
)) |
|
|
|
|
|
|
|
|
if results: |
|
|
all_scores = [s.confidence for s in results] |
|
|
min_score = min(all_scores) |
|
|
max_score = max(all_scores) |
|
|
score_range = max_score - min_score |
|
|
|
|
|
if score_range > 1e-9: |
|
|
for s in results: |
|
|
normalized_score = (s.confidence - min_score) / score_range |
|
|
s.confidence = round(normalized_score, 2) |
|
|
else: |
|
|
for s in results: |
|
|
s.confidence = round(s.confidence, 2) |
|
|
|
|
|
return results |
|
|
|
|
|
def format_lrc( |
|
|
self, |
|
|
sentence_timestamps: List[SentenceTimestamp], |
|
|
include_end_time: bool = False |
|
|
) -> str: |
|
|
""" |
|
|
Format sentence timestamps as LRC lyrics format. |
|
|
|
|
|
Args: |
|
|
sentence_timestamps: List of SentenceTimestamp objects |
|
|
include_end_time: Whether to include end time (enhanced LRC format) |
|
|
|
|
|
Returns: |
|
|
LRC formatted string |
|
|
""" |
|
|
lines = [] |
|
|
|
|
|
for sentence in sentence_timestamps: |
|
|
|
|
|
start_minutes = int(sentence.start // 60) |
|
|
start_seconds = sentence.start % 60 |
|
|
|
|
|
if include_end_time: |
|
|
end_minutes = int(sentence.end // 60) |
|
|
end_seconds = sentence.end % 60 |
|
|
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]" |
|
|
else: |
|
|
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]" |
|
|
|
|
|
|
|
|
text = sentence.text |
|
|
|
|
|
lines.append(f"{timestamp}{text}") |
|
|
|
|
|
return "\n".join(lines) |
|
|
|
|
|
def get_timestamps_and_lrc( |
|
|
self, |
|
|
calc_matrix: np.ndarray, |
|
|
lyrics_tokens: List[int], |
|
|
total_duration_seconds: float |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Convenience method to get both timestamps and LRC in one call. |
|
|
|
|
|
Args: |
|
|
calc_matrix: Processed attention matrix |
|
|
lyrics_tokens: List of token IDs |
|
|
total_duration_seconds: Total audio duration |
|
|
|
|
|
Returns: |
|
|
Dict containing token_timestamps, sentence_timestamps, and lrc_text |
|
|
""" |
|
|
token_stamps = self.token_timestamps( |
|
|
calc_matrix=calc_matrix, |
|
|
lyrics_tokens=lyrics_tokens, |
|
|
total_duration_seconds=total_duration_seconds |
|
|
) |
|
|
|
|
|
sentence_stamps = self.sentence_timestamps(token_stamps) |
|
|
lrc_text = self.format_lrc(sentence_stamps) |
|
|
|
|
|
return { |
|
|
"token_timestamps": token_stamps, |
|
|
"sentence_timestamps": sentence_stamps, |
|
|
"lrc_text": lrc_text |
|
|
} |
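
# Illustrative end-to-end sketch (hypothetical tokenizer and random tensors):
# the usual flow is stamps_align_info -> get_timestamps_and_lrc. A real
# attention tensor would come from the DiT cross-attention layers; random
# values exercise every step but produce meaningless timestamps.
def _demo_aligner_pipeline(tokenizer):
    aligner = MusicStampsAligner(tokenizer)
    attention = torch.rand(4, 8, 16, 100)  # [Layers, Heads, Tokens, Frames]
    lyrics_tokens = list(range(16))        # stand-in token IDs
    info = aligner.stamps_align_info(
        attention_matrix=attention,
        lyrics_tokens=lyrics_tokens,
        total_duration_seconds=30.0,
        custom_config={0: [0, 1], 1: [2]},  # layer index -> head indices
    )
    if info["calc_matrix"] is None:
        return info
    return aligner.get_timestamps_and_lrc(
        calc_matrix=info["calc_matrix"],
        lyrics_tokens=info["lyrics_tokens"],
        total_duration_seconds=info["total_duration_seconds"],
    )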
|
|
|
|
|
|
|
|
class MusicLyricScorer: |
|
|
""" |
|
|
Scorer class for evaluating lyrics-to-audio alignment quality. |
|
|
|
|
|
Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence) |
|
|
using tensor operations for potential differentiability or GPU acceleration. |
|
|
""" |
|
|
|
|
|
def __init__(self, tokenizer: Any): |
|
|
""" |
|
|
        Initialize the scorer.
|
|
|
|
|
Args: |
|
|
tokenizer: Tokenizer instance (must implement .decode()). |
|
|
""" |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray: |
|
|
""" |
|
|
Generate a mask distinguishing lyrics (1) from structural tags (0). |
|
|
Uses self.tokenizer to decode tokens. |
|
|
|
|
|
Args: |
|
|
token_ids: List of token IDs. |
|
|
|
|
|
Returns: |
|
|
Numpy array of shape [len(token_ids)] with 1 or 0. |
|
|
""" |
|
|
decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids] |
|
|
mask = np.ones(len(token_ids), dtype=np.int32) |
|
|
in_bracket = False |
|
|
|
|
|
for i, token_str in enumerate(decoded_tokens): |
|
|
if '[' in token_str: |
|
|
in_bracket = True |
|
|
if in_bracket: |
|
|
mask[i] = 0 |
|
|
if ']' in token_str: |
|
|
in_bracket = False |
|
|
mask[i] = 0 |
|
|
return mask |
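
    # Illustrative sketch with a stand-in tokenizer (hypothetical): structural
    # tags such as "[verse]" are masked to 0 while sung tokens keep a 1.
    def _demo_token_type_mask(self):
        class _ToyTokenizer:
            _vocab = {0: "[", 1: "verse", 2: "]", 3: "la", 4: "di"}

            def decode(self, ids):
                return "".join(self._vocab[i] for i in ids)

        toy = MusicLyricScorer(_ToyTokenizer())
        return toy._generate_token_type_mask([0, 1, 2, 3, 4])  # [0, 0, 0, 1, 1]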
|
|
|
|
|
def _preprocess_attention( |
|
|
self, |
|
|
attention_matrix: Union[torch.Tensor, np.ndarray], |
|
|
custom_config: Dict[int, List[int]], |
|
|
medfilt_width: int = 1 |
|
|
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]: |
|
|
""" |
|
|
Extracts and normalizes the attention matrix. |
|
|
|
|
|
Logic V4: Uses Min-Max normalization to highlight energy differences. |
|
|
|
|
|
Args: |
|
|
attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames]. |
|
|
custom_config: Config mapping layers to heads. |
|
|
medfilt_width: Width for median filtering. |
|
|
|
|
|
Returns: |
|
|
Tuple of (calc_matrix, energy_matrix, avg_weights_tensor). |
|
|
""" |
|
|
|
|
|
if not isinstance(attention_matrix, torch.Tensor): |
|
|
weights = torch.tensor(attention_matrix) |
|
|
else: |
|
|
weights = attention_matrix.clone() |
|
|
weights = weights.cpu().float() |
|
|
|
|
|
|
|
|
selected_tensors = [] |
|
|
for layer_idx, head_indices in custom_config.items(): |
|
|
for head_idx in head_indices: |
|
|
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]: |
|
|
selected_tensors.append(weights[layer_idx, head_idx]) |
|
|
|
|
|
if not selected_tensors: |
|
|
return None, None, None |
|
|
|
|
|
weights_stack = torch.stack(selected_tensors, dim=0) |
|
|
|
|
|
|
|
|
avg_weights = weights_stack.mean(dim=0) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
energy_tensor = median_filter(avg_weights, filter_width=medfilt_width) |
|
|
energy_matrix = energy_tensor.numpy() |
|
|
|
|
|
e_min, e_max = energy_matrix.min(), energy_matrix.max() |
|
|
|
|
|
if e_max - e_min > 1e-9: |
|
|
energy_matrix = (energy_matrix - e_min) / (e_max - e_min) |
|
|
else: |
|
|
energy_matrix = np.zeros_like(energy_matrix) |
|
|
|
|
|
|
|
|
|
|
|
calc_matrix = energy_matrix ** 2 |
|
|
|
|
|
return calc_matrix, energy_matrix, avg_weights |
|
|
|
|
|
def _compute_alignment_metrics( |
|
|
self, |
|
|
energy_matrix: torch.Tensor, |
|
|
path_coords: torch.Tensor, |
|
|
type_mask: torch.Tensor, |
|
|
time_weight: float = 0.01, |
|
|
overlap_frames: float = 9.0, |
|
|
instrumental_weight: float = 1.0 |
|
|
) -> Tuple[float, float, float]: |
|
|
""" |
|
|
Core metric calculation logic using high-precision Tensor operations. |
|
|
|
|
|
Args: |
|
|
energy_matrix: Normalized energy [Rows, Cols]. |
|
|
path_coords: DTW path coordinates [Steps, 2]. |
|
|
type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags). |
|
|
time_weight: Minimum energy threshold for monotonicity. |
|
|
overlap_frames: Allowed overlap for monotonicity check. |
|
|
instrumental_weight: Weight for non-lyric tokens in confidence calc. |
|
|
|
|
|
Returns: |
|
|
Tuple of (coverage, monotonicity, confidence). |
|
|
""" |
|
|
|
|
|
energy_matrix = energy_matrix.to(dtype=torch.float64) |
|
|
path_coords = path_coords.long() |
|
|
type_mask = type_mask.long() |
|
|
|
|
|
device = energy_matrix.device |
|
|
rows, cols = energy_matrix.shape |
|
|
|
|
|
is_lyrics_row = (type_mask == 1) |
|
|
|
|
|
|
|
|
|
|
|
row_max_energies = energy_matrix.max(dim=1).values |
|
|
total_sung_rows = is_lyrics_row.sum().double() |
|
|
|
|
|
coverage_threshold = 0.1 |
|
|
valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold) |
|
|
valid_sung_rows = valid_sung_mask.sum().double() |
|
|
|
|
|
if total_sung_rows > 0: |
|
|
coverage_score = valid_sung_rows / total_sung_rows |
|
|
else: |
|
|
coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64) |
|
|
|
|
|
|
|
|
|
|
|
col_indices = torch.arange(cols, device=device, dtype=torch.float64) |
|
|
|
|
|
|
|
|
weights = torch.where( |
|
|
energy_matrix > time_weight, |
|
|
energy_matrix, |
|
|
torch.zeros_like(energy_matrix) |
|
|
) |
|
|
|
|
|
sum_w = weights.sum(dim=1) |
|
|
sum_t = (weights * col_indices).sum(dim=1) |
|
|
|
|
|
|
|
|
centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64) |
|
|
valid_w_mask = sum_w > 1e-9 |
|
|
centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask] |
|
|
|
|
|
|
|
|
valid_sequence_mask = is_lyrics_row & (centroids >= 0) |
|
|
sung_centroids = centroids[valid_sequence_mask] |
|
|
|
|
|
cnt = sung_centroids.shape[0] |
|
|
if cnt > 1: |
|
|
curr_c = sung_centroids[:-1] |
|
|
next_c = sung_centroids[1:] |
|
|
|
|
|
|
|
|
non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum() |
|
|
pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64) |
|
|
monotonicity_score = non_decreasing / pairs |
|
|
else: |
|
|
monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64) |
|
|
|
|
|
|
|
|
|
|
|
if path_coords.shape[0] > 0: |
|
|
p_rows = path_coords[:, 0] |
|
|
p_cols = path_coords[:, 1] |
|
|
|
|
|
path_energies = energy_matrix[p_rows, p_cols] |
|
|
step_weights = torch.ones_like(path_energies) |
|
|
|
|
|
|
|
|
is_inst_step = (type_mask[p_rows] == 0) |
|
|
step_weights[is_inst_step] = instrumental_weight |
|
|
|
|
|
total_energy = (path_energies * step_weights).sum() |
|
|
total_steps = step_weights.sum() |
|
|
|
|
|
if total_steps > 0: |
|
|
path_confidence = total_energy / total_steps |
|
|
else: |
|
|
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64) |
|
|
else: |
|
|
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64) |
|
|
|
|
|
return coverage_score.item(), monotonicity_score.item(), path_confidence.item() |
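
    # Illustrative sketch: metrics on a hand-built diagonal case. A perfectly
    # diagonal energy matrix over all-lyrics rows should score 1.0 on every
    # metric with the default thresholds.
    def _demo_alignment_metrics(self):
        energy = torch.eye(3, dtype=torch.float64)     # 3 tokens x 3 frames
        path = torch.tensor([[0, 0], [1, 1], [2, 2]])  # diagonal DTW path
        mask = torch.ones(3, dtype=torch.long)         # every row is lyrics
        return self._compute_alignment_metrics(energy, path, mask)  # (1.0, 1.0, 1.0)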
|
|
|
|
|
def lyrics_alignment_info( |
|
|
self, |
|
|
attention_matrix: Union[torch.Tensor, np.ndarray], |
|
|
token_ids: List[int], |
|
|
custom_config: Dict[int, List[int]], |
|
|
return_matrices: bool = False, |
|
|
medfilt_width: int = 1 |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Generates alignment path and processed matrices. |
|
|
|
|
|
Args: |
|
|
attention_matrix: Input attention tensor. |
|
|
token_ids: Corresponding token IDs. |
|
|
custom_config: Layer/Head configuration. |
|
|
return_matrices: If True, returns matrices in the output. |
|
|
medfilt_width: Median filter width. |
|
|
|
|
|
Returns: |
|
|
            Dict containing path_coords, type_mask, and energy_matrix
            (plus calc_matrix and vis_matrix when return_matrices is True).
|
|
""" |
|
|
calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention( |
|
|
attention_matrix, custom_config, medfilt_width |
|
|
) |
|
|
|
|
|
if calc_matrix is None: |
|
|
return { |
|
|
"calc_matrix": None, |
|
|
"error": "No valid attention heads found" |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
type_mask = self._generate_token_type_mask(token_ids) |
|
|
|
|
|
|
|
|
if len(type_mask) != energy_matrix.shape[0]: |
|
|
|
|
|
type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32) |
|
|
|
|
|
|
|
|
|
|
|
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32)) |
|
|
path_coords = np.stack([text_indices, time_indices], axis=1) |
|
|
|
|
|
return_dict = { |
|
|
"path_coords": path_coords, |
|
|
"type_mask": type_mask, |
|
|
"energy_matrix": energy_matrix |
|
|
} |
|
|
if return_matrices: |
|
|
return_dict['calc_matrix'] = calc_matrix |
|
|
return_dict['vis_matrix'] = vis_matrix |
|
|
|
|
|
return return_dict |
|
|
|
|
|
def calculate_score( |
|
|
self, |
|
|
energy_matrix: Union[torch.Tensor, np.ndarray], |
|
|
type_mask: Union[torch.Tensor, np.ndarray], |
|
|
path_coords: Union[torch.Tensor, np.ndarray], |
|
|
time_weight: float = 0.01, |
|
|
overlap_frames: float = 9.0, |
|
|
instrumental_weight: float = 1.0 |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Calculates the final alignment score based on pre-computed components. |
|
|
|
|
|
Args: |
|
|
energy_matrix: Processed energy matrix. |
|
|
type_mask: Token type mask. |
|
|
path_coords: DTW path coordinates. |
|
|
time_weight: Minimum energy threshold for monotonicity. |
|
|
overlap_frames: Allowed backward movement frames. |
|
|
instrumental_weight: Weight for non-lyric path steps. |
|
|
|
|
|
Returns: |
|
|
            Dict containing the final rounded lyrics_score.
|
|
""" |
|
|
|
|
|
        if not isinstance(energy_matrix, torch.Tensor):
            # Fall back to CPU when CUDA is unavailable instead of hardcoding 'cuda'.
            energy_matrix = torch.tensor(
                energy_matrix,
                device='cuda' if torch.cuda.is_available() else 'cpu',
                dtype=torch.float32,
            )
|
|
|
|
|
device = energy_matrix.device |
|
|
|
|
|
if not isinstance(type_mask, torch.Tensor): |
|
|
type_mask = torch.tensor(type_mask, device=device, dtype=torch.long) |
|
|
else: |
|
|
type_mask = type_mask.to(device=device, dtype=torch.long) |
|
|
|
|
|
if not isinstance(path_coords, torch.Tensor): |
|
|
path_coords = torch.tensor(path_coords, device=device, dtype=torch.long) |
|
|
else: |
|
|
path_coords = path_coords.to(device=device, dtype=torch.long) |
|
|
|
|
|
|
|
|
coverage, monotonicity, confidence = self._compute_alignment_metrics( |
|
|
energy_matrix=energy_matrix, |
|
|
path_coords=path_coords, |
|
|
type_mask=type_mask, |
|
|
time_weight=time_weight, |
|
|
overlap_frames=overlap_frames, |
|
|
instrumental_weight=instrumental_weight |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
final_score = (coverage ** 2) * (monotonicity ** 2) * confidence |
|
|
final_score = float(np.clip(final_score, 0.0, 1.0)) |
|
|
|
|
|
return { |
|
|
"lyrics_score": round(final_score, 4) |
|
|
} |
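
# Illustrative end-to-end sketch for the scorer (hypothetical tokenizer and
# random attention): lyrics_alignment_info produces the path and masks, and
# calculate_score folds them into a single scalar in [0, 1].
def _demo_scorer_pipeline(tokenizer):
    scorer = MusicLyricScorer(tokenizer)
    attention = torch.rand(4, 8, 16, 100)  # [Layers, Heads, Tokens, Frames]
    token_ids = list(range(16))            # stand-in token IDs
    info = scorer.lyrics_alignment_info(attention, token_ids, {0: [0], 1: [1]})
    if "error" in info:
        return info
    return scorer.calculate_score(
        energy_matrix=info["energy_matrix"],
        type_mask=info["type_mask"],
        path_coords=info["path_coords"],
    )  # -> {"lyrics_score": <float in [0, 1]>}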