Spaces:
Running
on
A100
Running
on
A100
add lyric (LRC) support
Browse files- acestep/dit_alignment_score.py +547 -0
- acestep/gradio_ui/events/__init__.py +35 -11
- acestep/gradio_ui/events/results_handlers.py +130 -22
- acestep/gradio_ui/i18n/en.json +11 -8
- acestep/gradio_ui/i18n/ja.json +11 -8
- acestep/gradio_ui/i18n/zh.json +11 -8
- acestep/gradio_ui/interfaces/result.py +148 -57
- acestep/handler.py +259 -13
- pyproject.toml +3 -3
acestep/dit_alignment_score.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT Alignment Score Module
|
| 3 |
+
|
| 4 |
+
This module provides lyrics-to-audio alignment using cross-attention matrices
|
| 5 |
+
from DiT model for generating LRC timestamps.
|
| 6 |
+
|
| 7 |
+
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
|
| 8 |
+
"""
|
| 9 |
+
import numba
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ================= Data Classes =================
|
| 18 |
+
@dataclass
class TokenTimestamp:
    """Stores per-token timing information."""
    token_id: int       # vocabulary id of the lyrics token
    text: str           # decoded text this token contributes (may be "" for a partial UTF-8 byte token)
    start: float        # start time in seconds
    end: float          # end time in seconds
    probability: float  # alignment confidence; 0.0 when not computed
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class SentenceTimestamp:
    """Stores per-sentence timing information with token list."""
    text: str                     # stripped sentence text (decoded from all token ids at once)
    start: float                  # start of the first token, seconds, rounded to 3 decimals
    end: float                    # end of the last token, seconds, rounded to 3 decimals
    tokens: List[TokenTimestamp]  # the tokens that make up this sentence
    confidence: float             # mean of positive token probabilities; later min-max normalized
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ================= DTW Algorithm (Numba Optimized) =================
|
| 39 |
+
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Computes the minimum-cost monotonic path through the cost matrix and
    returns it via `_backtrace`.

    Args:
        x: Cost matrix of shape [N, M] (N text tokens, M time frames).

    Returns:
        Tuple of (text_indices, time_indices) arrays
    """
    N, M = x.shape
    # Use float32 for memory efficiency
    # cost[i, j] = cheapest accumulated cost to reach cell (i-1, j-1) of x;
    # the extra row/column 0 holds the +inf boundary, with only (0,0) free.
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    # trace stores which predecessor won: 0 = diagonal, 1 = up, 2 = left.
    # (float dtype is fine here; values compared exactly in _backtrace.)
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            # Ties fall through to c2 (move left), biasing the path toward
            # consuming time frames.
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return _backtrace(trace, N, M)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    Walks the trace matrix from (N, M) back to (0, 0), emitting the optimal
    warping path in forward order.

    Args:
        trace: Trace matrix of shape (N+1, M+1); 0 = diagonal, 1 = up, 2 = left
        N, M: Original matrix dimensions

    Returns:
        Path array of shape (2, path_len) - first row is text indices, second is time indices
    """
    # Boundary handling: force moves along the borders so the walk always
    # terminates at (0, 0).
    trace[0, :] = 2
    trace[:, 0] = 1

    # Pre-allocate array, max path length is N+M; fill from the back so the
    # path comes out in forward order without a reverse pass.
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)

    i, j = N, M
    path_idx = max_path_len - 1

    while i > 0 or j > 0:
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1

        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            # Uninitialized cell (-1); defensive stop.
            break

    # Slice off the unused prefix of the pre-allocated buffer.
    # (Removed the unused `actual_len` local from the original.)
    return path[:, path_idx + 1:max_path_len]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ================= Utility Functions =================
|
| 119 |
+
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a median filter along the last dimension of a tensor.

    Accepts a 2-D [Tokens, Frames] or 3-D [Heads, Tokens, Frames] tensor.
    The input is reflect-padded so the output has the same length as the
    input along the filtered dimension.

    Args:
        x: Input tensor (2-D or 3-D)
        filter_width: Width of median filter (odd widths give a true median)

    Returns:
        Filtered tensor with the same shape as the input.

    Bug fix: the original squeezed dim 0 whenever the result had more than
    two dims, which silently dropped the head dimension of a 3-D input with
    a single head (e.g. one selected attention head). We now squeeze only
    the dim we added ourselves.
    """
    pad_width = filter_width // 2
    # Reflect padding requires the signal to be longer than the pad.
    if x.shape[-1] <= pad_width:
        return x

    added_dim = False
    if x.ndim == 2:
        # F.pad with a 4-element pad spec and mode="reflect" needs >= 3 dims.
        x = x[None, :]
        added_dim = True

    x = F.pad(x, (pad_width, pad_width, 0, 0), mode="reflect")
    # Sliding windows of `filter_width`, sorted; the middle element is the median.
    result = x.unfold(-1, filter_width, 1).sort()[0][..., pad_width]

    if added_dim:
        result = result.squeeze(0)
    return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ================= Main Aligner Class =================
|
| 143 |
+
class MusicStampsAligner:
    """
    Aligner class for generating lyrics timestamps from cross-attention matrices.

    Pipeline: select configured attention heads -> bidirectional-consensus
    denoising -> DTW over the token/frame matrix -> per-token timestamps ->
    per-sentence timestamps -> LRC text.
    """

    def __init__(self, tokenizer):
        """
        Initialize the aligner.

        Args:
            tokenizer: Text tokenizer for decoding tokens
        """
        self.tokenizer = tokenizer

    def _apply_bidirectional_consensus(
        self,
        weights_stack: torch.Tensor,
        violence_level: float,
        medfilt_width: int
    ) -> tuple:
        """
        Core denoising logic using bidirectional consensus.

        Args:
            weights_stack: Attention weights [Heads, Tokens, Frames]
            violence_level: Denoising strength coefficient
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix) as numpy arrays, both
            head-averaged to shape [Tokens, Frames]
        """
        # A. Bidirectional consensus: a cell survives only if it is strong in
        # both the token->frame and the frame->token distribution.
        row_prob = F.softmax(weights_stack, dim=-1)  # Token -> Frame
        col_prob = F.softmax(weights_stack, dim=-2)  # Frame -> Token
        processed = row_prob * col_prob

        # B1. Row suppression (kill horizontal crossing lines).
        row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
        processed = torch.relu(processed - violence_level * row_medians)

        # B2. Column suppression (kill vertical crossing lines).
        col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
        processed = torch.relu(processed - violence_level * col_medians)

        # C. Power sharpening.
        processed = processed ** 2

        # Pre-normalization energy, kept for confidence estimation.
        energy_matrix = processed.mean(dim=0).cpu().numpy()

        # D. Global z-score normalization.
        std, mean = torch.std_mean(processed, unbiased=False)
        weights_processed = (processed - mean) / (std + 1e-9)

        # E. Median filtering along the frame axis.
        weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
        calc_matrix = weights_processed.mean(dim=0).numpy()

        return calc_matrix, energy_matrix

    def _preprocess_attention(
        self,
        attention_matrix: torch.Tensor,
        custom_config: Dict[int, List[int]],
        violence_level: float,
        medfilt_width: int = 7
    ) -> tuple:
        """
        Preprocess attention matrix for alignment.

        Args:
            attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
            custom_config: Dict mapping layer indices to head indices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Tuple of (calc_matrix, energy_matrix, visual_matrix), or
            (None, None, None) when no configured head index is in range.
        """
        if isinstance(attention_matrix, torch.Tensor):
            weights = attention_matrix.clone()
        else:
            weights = torch.tensor(attention_matrix)

        weights = weights.cpu().float()

        # Collect configured (layer, head) maps; out-of-range indices are
        # silently skipped so one config can serve models of different sizes.
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])

        if not selected_tensors:
            return None, None, None

        # Stack selected heads: [Heads, Tokens, Frames]
        weights_stack = torch.stack(selected_tensors, dim=0)
        visual_matrix = weights_stack.mean(dim=0).numpy()

        calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
            weights_stack, violence_level, medfilt_width
        )

        return calc_matrix, energy_matrix, visual_matrix

    def stamps_align_info(
        self,
        attention_matrix: torch.Tensor,
        lyrics_tokens: List[int],
        total_duration_seconds: float,
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        violence_level: float = 2.0,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Get alignment information from attention matrix.

        Args:
            attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
            lyrics_tokens: List of lyrics token IDs
            total_duration_seconds: Total audio duration in seconds
            custom_config: Dict mapping layer indices to head indices
            return_matrices: Whether to return intermediate matrices
            violence_level: Denoising strength
            medfilt_width: Median filter width

        Returns:
            Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
            and optionally energy_matrix and vis_matrix. On failure calc_matrix
            is None and an "error" key describes the problem.
        """
        calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
            attention_matrix, custom_config, violence_level, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "lyrics_tokens": lyrics_tokens,
                "total_duration_seconds": total_duration_seconds,
                "error": "No valid attention heads found"
            }

        return_dict = {
            "calc_matrix": calc_matrix,
            "lyrics_tokens": lyrics_tokens,
            "total_duration_seconds": total_duration_seconds
        }

        if return_matrices:
            return_dict['energy_matrix'] = energy_matrix
            return_dict['vis_matrix'] = visual_matrix

        return return_dict

    def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
        """
        Decode tokens incrementally to properly handle multi-byte UTF-8 characters.

        For Chinese and other multi-byte characters, the tokenizer may split them
        into multiple byte-level tokens. Decoding each token individually produces
        invalid UTF-8 sequences (showing as \ufffd). This method uses byte-level
        comparison to correctly track which characters each token contributes.

        Note: decodes the prefix token_ids[:i+1] at every step, so this is
        O(n^2) in the token count — acceptable for lyrics-length sequences.

        Args:
            token_ids: List of token IDs

        Returns:
            List of decoded text for each token position
        """
        decoded_tokens = []
        prev_bytes = b""

        for i in range(len(token_ids)):
            # Decode tokens from start to current position
            current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
            current_bytes = current_text.encode('utf-8', errors='surrogatepass')

            # The contribution of current token is the new bytes added
            if len(current_bytes) >= len(prev_bytes):
                new_bytes = current_bytes[len(prev_bytes):]
                # Try to decode the new bytes; if incomplete, use empty string
                try:
                    token_text = new_bytes.decode('utf-8')
                except UnicodeDecodeError:
                    # Incomplete UTF-8 sequence, this token doesn't complete a character
                    token_text = ""
            else:
                # Edge case: current decode is shorter (shouldn't happen normally)
                token_text = ""

            decoded_tokens.append(token_text)
            prev_bytes = current_bytes

        return decoded_tokens

    def token_timestamps(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> List[TokenTimestamp]:
        """
        Generate per-token timestamps using DTW.

        Args:
            calc_matrix: Processed attention matrix [Tokens, Frames]
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            List of TokenTimestamp objects (one per input token)
        """
        n_frames = calc_matrix.shape[-1]
        # DTW minimizes cost, so negate the (similarity-like) matrix.
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))

        seconds_per_frame = total_duration_seconds / n_frames
        alignment_results = []

        # Use incremental decoding to properly handle multi-byte UTF-8 characters
        decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)

        for i in range(len(lyrics_tokens)):
            mask = (text_indices == i)

            if not np.any(mask):
                # Token never visited by the DTW path: give it a zero-length
                # span starting at the previous token's end.
                start = alignment_results[-1].end if alignment_results else 0.0
                end = start
                token_conf = 0.0
            else:
                times = time_indices[mask] * seconds_per_frame
                start = times[0]
                end = times[-1]
                # Placeholder: per-token probability not yet derived from the
                # energy matrix.
                token_conf = 0.0

            if end < start:
                end = start

            alignment_results.append(TokenTimestamp(
                token_id=lyrics_tokens[i],
                text=decoded_tokens[i],
                start=float(start),
                end=float(end),
                probability=token_conf
            ))

        return alignment_results

    def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
        """
        Decode a sentence by decoding all token IDs together.
        This avoids UTF-8 encoding issues from joining individual token texts.

        Args:
            tokens: List of TokenTimestamp objects

        Returns:
            Properly decoded sentence text
        """
        token_ids = [t.token_id for t in tokens]
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    def _build_sentence(self, tokens: List[TokenTimestamp]) -> Optional[SentenceTimestamp]:
        """
        Build one SentenceTimestamp from a run of tokens.

        Decodes all token IDs together (avoiding UTF-8 issues) and averages
        the positive token probabilities for the sentence confidence.

        Args:
            tokens: Non-empty list of TokenTimestamp objects

        Returns:
            SentenceTimestamp, or None when the decoded text is blank.
        """
        full_text = self._decode_sentence_from_tokens(tokens)
        if not full_text.strip():
            return None

        valid_scores = [t.probability for t in tokens if t.probability > 0]
        sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0

        return SentenceTimestamp(
            text=full_text.strip(),
            start=round(tokens[0].start, 3),
            end=round(tokens[-1].end, 3),
            tokens=list(tokens),
            confidence=sent_conf
        )

    def sentence_timestamps(
        self,
        token_alignment: List[TokenTimestamp]
    ) -> List[SentenceTimestamp]:
        """
        Group token timestamps into sentence timestamps.

        Sentences are split on tokens whose decoded text contains a newline.
        Confidence scores are min-max normalized across the returned list.

        Args:
            token_alignment: List of TokenTimestamp objects

        Returns:
            List of SentenceTimestamp objects
        """
        results = []
        current_tokens = []

        for token in token_alignment:
            current_tokens.append(token)

            if '\n' in token.text:
                sentence = self._build_sentence(current_tokens)
                if sentence is not None:
                    results.append(sentence)
                current_tokens = []

        # Handle trailing tokens with no final newline.
        if current_tokens:
            sentence = self._build_sentence(current_tokens)
            if sentence is not None:
                results.append(sentence)

        # Normalize confidence scores to [0, 1] across all sentences.
        if results:
            all_scores = [s.confidence for s in results]
            min_score = min(all_scores)
            max_score = max(all_scores)
            score_range = max_score - min_score

            if score_range > 1e-9:
                for s in results:
                    normalized_score = (s.confidence - min_score) / score_range
                    s.confidence = round(normalized_score, 2)
            else:
                # All scores (near-)equal: just round them.
                for s in results:
                    s.confidence = round(s.confidence, 2)

        return results

    def format_lrc(
        self,
        sentence_timestamps: List[SentenceTimestamp],
        include_end_time: bool = False
    ) -> str:
        """
        Format sentence timestamps as LRC lyrics format.

        Args:
            sentence_timestamps: List of SentenceTimestamp objects
            include_end_time: Whether to include end time (enhanced LRC format)

        Returns:
            LRC formatted string, one "[mm:ss.xx]text" line per sentence
        """
        lines = []

        for sentence in sentence_timestamps:
            # Convert seconds to mm:ss.xx format
            start_minutes = int(sentence.start // 60)
            start_seconds = sentence.start % 60

            if include_end_time:
                end_minutes = int(sentence.end // 60)
                end_seconds = sentence.end % 60
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
            else:
                timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"

            # TODO(review): structural tags like [verse]/[chorus] are NOT
            # stripped here despite the original comment claiming so.
            text = sentence.text

            lines.append(f"{timestamp}{text}")

        return "\n".join(lines)

    def get_timestamps_and_lrc(
        self,
        calc_matrix: np.ndarray,
        lyrics_tokens: List[int],
        total_duration_seconds: float
    ) -> Dict[str, Any]:
        """
        Convenience method to get both timestamps and LRC in one call.

        Args:
            calc_matrix: Processed attention matrix
            lyrics_tokens: List of token IDs
            total_duration_seconds: Total audio duration

        Returns:
            Dict containing token_timestamps, sentence_timestamps, and lrc_text
        """
        token_stamps = self.token_timestamps(
            calc_matrix=calc_matrix,
            lyrics_tokens=lyrics_tokens,
            total_duration_seconds=total_duration_seconds
        )

        sentence_stamps = self.sentence_timestamps(token_stamps)
        lrc_text = self.format_lrc(sentence_stamps)

        return {
            "token_timestamps": token_stamps,
            "sentence_timestamps": sentence_stamps,
            "lrc_text": lrc_text
        }
|
| 547 |
+
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -358,19 +358,49 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 358 |
)
|
| 359 |
|
| 360 |
# ========== Score Calculation Handlers ==========
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
for btn_idx in range(1, 9):
|
| 362 |
results_section[f"score_btn_{btn_idx}"].click(
|
| 363 |
-
fn=
|
| 364 |
-
llm_handler, sample_idx, scale, batch_idx, queue
|
| 365 |
-
),
|
| 366 |
inputs=[
|
| 367 |
-
gr.State(value=btn_idx),
|
| 368 |
generation_section["score_scale"],
|
| 369 |
results_section["current_batch_index"],
|
| 370 |
results_section["batch_queue"],
|
| 371 |
],
|
| 372 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
def generation_wrapper(*args):
|
| 375 |
yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
|
| 376 |
# ========== Generation Handler ==========
|
|
@@ -438,12 +468,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 438 |
results_section["generation_info"],
|
| 439 |
results_section["status_output"],
|
| 440 |
generation_section["seed"],
|
| 441 |
-
results_section["align_score_1"],
|
| 442 |
-
results_section["align_text_1"],
|
| 443 |
-
results_section["align_plot_1"],
|
| 444 |
-
results_section["align_score_2"],
|
| 445 |
-
results_section["align_text_2"],
|
| 446 |
-
results_section["align_plot_2"],
|
| 447 |
results_section["score_display_1"],
|
| 448 |
results_section["score_display_2"],
|
| 449 |
results_section["score_display_3"],
|
|
|
|
| 358 |
)
|
| 359 |
|
| 360 |
# ========== Score Calculation Handlers ==========
|
| 361 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 362 |
+
def make_score_handler(idx):
|
| 363 |
+
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
|
| 364 |
+
llm_handler, idx, scale, batch_idx, queue
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
for btn_idx in range(1, 9):
|
| 368 |
results_section[f"score_btn_{btn_idx}"].click(
|
| 369 |
+
fn=make_score_handler(btn_idx),
|
|
|
|
|
|
|
| 370 |
inputs=[
|
|
|
|
| 371 |
generation_section["score_scale"],
|
| 372 |
results_section["current_batch_index"],
|
| 373 |
results_section["batch_queue"],
|
| 374 |
],
|
| 375 |
+
outputs=[
|
| 376 |
+
results_section[f"score_display_{btn_idx}"],
|
| 377 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 378 |
+
results_section["batch_queue"]
|
| 379 |
+
]
|
| 380 |
)
|
| 381 |
+
|
| 382 |
+
# ========== LRC Timestamp Handlers ==========
|
| 383 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 384 |
+
def make_lrc_handler(idx):
|
| 385 |
+
return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
|
| 386 |
+
dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
for btn_idx in range(1, 9):
|
| 390 |
+
results_section[f"lrc_btn_{btn_idx}"].click(
|
| 391 |
+
fn=make_lrc_handler(btn_idx),
|
| 392 |
+
inputs=[
|
| 393 |
+
results_section["current_batch_index"],
|
| 394 |
+
results_section["batch_queue"],
|
| 395 |
+
generation_section["vocal_language"],
|
| 396 |
+
generation_section["inference_steps"],
|
| 397 |
+
],
|
| 398 |
+
outputs=[
|
| 399 |
+
results_section[f"lrc_display_{btn_idx}"],
|
| 400 |
+
results_section[f"details_accordion_{btn_idx}"]
|
| 401 |
+
]
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
def generation_wrapper(*args):
|
| 405 |
yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
|
| 406 |
# ========== Generation Handler ==========
|
|
|
|
| 468 |
results_section["generation_info"],
|
| 469 |
results_section["status_output"],
|
| 470 |
generation_section["seed"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
results_section["score_display_1"],
|
| 472 |
results_section["score_display_2"],
|
| 473 |
results_section["score_display_3"],
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -141,6 +141,7 @@ def store_batch_in_queue(
|
|
| 141 |
batch_size=2,
|
| 142 |
generation_params=None,
|
| 143 |
lm_generated_metadata=None,
|
|
|
|
| 144 |
status="completed"
|
| 145 |
):
|
| 146 |
"""Store batch results in queue with ALL generation parameters
|
|
@@ -152,6 +153,7 @@ def store_batch_in_queue(
|
|
| 152 |
batch_size: Batch size used for this batch
|
| 153 |
generation_params: Complete dictionary of ALL generation parameters used
|
| 154 |
lm_generated_metadata: LM-generated metadata for scoring (optional)
|
|
|
|
| 155 |
"""
|
| 156 |
batch_queue[batch_index] = {
|
| 157 |
"status": status,
|
|
@@ -164,6 +166,7 @@ def store_batch_in_queue(
|
|
| 164 |
"batch_size": batch_size, # Store batch size
|
| 165 |
"generation_params": generation_params if generation_params else {}, # Store ALL parameters
|
| 166 |
"lm_generated_metadata": lm_generated_metadata, # Store LM metadata for scoring
|
|
|
|
| 167 |
"timestamp": datetime.datetime.now().isoformat()
|
| 168 |
}
|
| 169 |
return batch_queue
|
|
@@ -355,12 +358,6 @@ def generate_with_progress(
|
|
| 355 |
audio_conversion_start_time = time_module.time()
|
| 356 |
total_auto_score_time = 0.0
|
| 357 |
|
| 358 |
-
align_score_1 = ""
|
| 359 |
-
align_text_1 = ""
|
| 360 |
-
align_plot_1 = None
|
| 361 |
-
align_score_2 = ""
|
| 362 |
-
align_text_2 = ""
|
| 363 |
-
align_plot_2 = None
|
| 364 |
updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
|
| 365 |
|
| 366 |
# Build initial generation_info (will be updated with post-processing times at the end)
|
|
@@ -373,7 +370,7 @@ def generate_with_progress(
|
|
| 373 |
)
|
| 374 |
|
| 375 |
if not result.success:
|
| 376 |
-
yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) *
|
| 377 |
return
|
| 378 |
|
| 379 |
audios = result.audios
|
|
@@ -421,8 +418,6 @@ def generate_with_progress(
|
|
| 421 |
generation_info,
|
| 422 |
status_message,
|
| 423 |
seed_value_for_ui,
|
| 424 |
-
# Align plot placeholders (assume no need to update in real time)
|
| 425 |
-
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
| 426 |
# Scores
|
| 427 |
scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
|
| 428 |
updated_audio_codes,
|
|
@@ -431,6 +426,7 @@ def generate_with_progress(
|
|
| 431 |
audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
|
| 432 |
lm_generated_metadata,
|
| 433 |
is_format_caption,
|
|
|
|
| 434 |
)
|
| 435 |
else:
|
| 436 |
# If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
|
|
@@ -467,7 +463,6 @@ def generate_with_progress(
|
|
| 467 |
generation_info,
|
| 468 |
"Generation Complete",
|
| 469 |
seed_value_for_ui,
|
| 470 |
-
align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
|
| 471 |
final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
|
| 472 |
final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
|
| 473 |
updated_audio_codes,
|
|
@@ -475,6 +470,7 @@ def generate_with_progress(
|
|
| 475 |
final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
|
| 476 |
lm_generated_metadata,
|
| 477 |
is_format_caption,
|
|
|
|
| 478 |
)
|
| 479 |
|
| 480 |
|
|
@@ -595,7 +591,7 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
|
|
| 595 |
batch_queue: Batch queue containing historical generation data
|
| 596 |
"""
|
| 597 |
if current_batch_index not in batch_queue:
|
| 598 |
-
return
|
| 599 |
|
| 600 |
batch_data = batch_queue[current_batch_index]
|
| 601 |
params = batch_data.get("generation_params", {})
|
|
@@ -642,7 +638,106 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
|
|
| 642 |
batch_queue[current_batch_index]["scores"] = [""] * 8
|
| 643 |
batch_queue[current_batch_index]["scores"][sample_idx - 1] = score_display
|
| 644 |
|
| 645 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 646 |
|
| 647 |
|
| 648 |
def capture_current_params(
|
|
@@ -758,7 +853,9 @@ def generate_with_batch_management(
|
|
| 758 |
final_result_from_inner = partial_result
|
| 759 |
# current_batch_index, total_batches, batch_queue, next_params,
|
| 760 |
# batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
|
| 761 |
-
|
|
|
|
|
|
|
| 762 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
| 763 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 764 |
)
|
|
@@ -766,21 +863,23 @@ def generate_with_batch_management(
|
|
| 766 |
all_audio_paths = result[8]
|
| 767 |
|
| 768 |
if all_audio_paths is None:
|
| 769 |
-
|
| 770 |
-
|
|
|
|
| 771 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
| 772 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 773 |
)
|
| 774 |
return
|
| 775 |
|
| 776 |
# Extract results from generation (使用 result 下标访问)
|
|
|
|
| 777 |
generation_info = result[9]
|
| 778 |
seed_value_for_ui = result[11]
|
| 779 |
-
lm_generated_metadata = result[
|
| 780 |
|
| 781 |
# Extract codes
|
| 782 |
-
generated_codes_single = result[26
|
| 783 |
-
generated_codes_batch = [result[
|
| 784 |
|
| 785 |
# Determine which codes to store based on mode
|
| 786 |
if allow_lm_batch and batch_size_input >= 2:
|
|
@@ -839,6 +938,9 @@ def generate_with_batch_management(
|
|
| 839 |
next_params["text2music_audio_code_string"] = ""
|
| 840 |
next_params["random_seed_checkbox"] = True
|
| 841 |
|
|
|
|
|
|
|
|
|
|
| 842 |
# Store current batch in queue
|
| 843 |
batch_queue = store_batch_in_queue(
|
| 844 |
batch_queue,
|
|
@@ -851,6 +953,7 @@ def generate_with_batch_management(
|
|
| 851 |
batch_size=int(batch_size_input),
|
| 852 |
generation_params=saved_params,
|
| 853 |
lm_generated_metadata=lm_generated_metadata,
|
|
|
|
| 854 |
status="completed"
|
| 855 |
)
|
| 856 |
|
|
@@ -870,7 +973,9 @@ def generate_with_batch_management(
|
|
| 870 |
|
| 871 |
# 4. Yield final result (includes Batch UI updates)
|
| 872 |
# The result here is already a tuple structure
|
| 873 |
-
|
|
|
|
|
|
|
| 874 |
current_batch_index,
|
| 875 |
total_batches,
|
| 876 |
batch_queue,
|
|
@@ -1040,14 +1145,15 @@ def generate_next_batch_background(
|
|
| 1040 |
final_result = partial_result
|
| 1041 |
|
| 1042 |
# Extract results from final_result
|
|
|
|
| 1043 |
all_audio_paths = final_result[8] # generated_audio_batch
|
| 1044 |
generation_info = final_result[9]
|
| 1045 |
seed_value_for_ui = final_result[11]
|
| 1046 |
-
lm_generated_metadata = final_result[
|
| 1047 |
|
| 1048 |
# Extract codes
|
| 1049 |
-
generated_codes_single = final_result[26
|
| 1050 |
-
generated_codes_batch = [final_result[
|
| 1051 |
|
| 1052 |
# Determine which codes to store
|
| 1053 |
batch_size = params.get("batch_size_input", 2)
|
|
@@ -1070,6 +1176,7 @@ def generate_next_batch_background(
|
|
| 1070 |
logger.info(f" - codes_to_store: STRING with {len(codes_to_store) if codes_to_store else 0} chars")
|
| 1071 |
|
| 1072 |
# Store next batch in queue with codes, batch settings, and ALL generation params
|
|
|
|
| 1073 |
batch_queue = store_batch_in_queue(
|
| 1074 |
batch_queue,
|
| 1075 |
next_batch_idx,
|
|
@@ -1081,6 +1188,7 @@ def generate_next_batch_background(
|
|
| 1081 |
batch_size=int(batch_size),
|
| 1082 |
generation_params=params,
|
| 1083 |
lm_generated_metadata=lm_generated_metadata,
|
|
|
|
| 1084 |
status="completed"
|
| 1085 |
)
|
| 1086 |
|
|
|
|
| 141 |
batch_size=2,
|
| 142 |
generation_params=None,
|
| 143 |
lm_generated_metadata=None,
|
| 144 |
+
extra_outputs=None,
|
| 145 |
status="completed"
|
| 146 |
):
|
| 147 |
"""Store batch results in queue with ALL generation parameters
|
|
|
|
| 153 |
batch_size: Batch size used for this batch
|
| 154 |
generation_params: Complete dictionary of ALL generation parameters used
|
| 155 |
lm_generated_metadata: LM-generated metadata for scoring (optional)
|
| 156 |
+
extra_outputs: Dictionary containing pred_latents, encoder_hidden_states, etc. for LRC generation
|
| 157 |
"""
|
| 158 |
batch_queue[batch_index] = {
|
| 159 |
"status": status,
|
|
|
|
| 166 |
"batch_size": batch_size, # Store batch size
|
| 167 |
"generation_params": generation_params if generation_params else {}, # Store ALL parameters
|
| 168 |
"lm_generated_metadata": lm_generated_metadata, # Store LM metadata for scoring
|
| 169 |
+
"extra_outputs": extra_outputs if extra_outputs else {}, # Store extra outputs for LRC generation
|
| 170 |
"timestamp": datetime.datetime.now().isoformat()
|
| 171 |
}
|
| 172 |
return batch_queue
|
|
|
|
| 358 |
audio_conversion_start_time = time_module.time()
|
| 359 |
total_auto_score_time = 0.0
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
|
| 362 |
|
| 363 |
# Build initial generation_info (will be updated with post-processing times at the end)
|
|
|
|
| 370 |
)
|
| 371 |
|
| 372 |
if not result.success:
|
| 373 |
+
yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 20 + (None,) # +1 for extra_outputs
|
| 374 |
return
|
| 375 |
|
| 376 |
audios = result.audios
|
|
|
|
| 418 |
generation_info,
|
| 419 |
status_message,
|
| 420 |
seed_value_for_ui,
|
|
|
|
|
|
|
| 421 |
# Scores
|
| 422 |
scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
|
| 423 |
updated_audio_codes,
|
|
|
|
| 426 |
audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
|
| 427 |
lm_generated_metadata,
|
| 428 |
is_format_caption,
|
| 429 |
+
None, # Placeholder for extra_outputs (only filled in final yield)
|
| 430 |
)
|
| 431 |
else:
|
| 432 |
# If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
|
|
|
|
| 463 |
generation_info,
|
| 464 |
"Generation Complete",
|
| 465 |
seed_value_for_ui,
|
|
|
|
| 466 |
final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
|
| 467 |
final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
|
| 468 |
updated_audio_codes,
|
|
|
|
| 470 |
final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
|
| 471 |
lm_generated_metadata,
|
| 472 |
is_format_caption,
|
| 473 |
+
result.extra_outputs, # extra_outputs for LRC generation
|
| 474 |
)
|
| 475 |
|
| 476 |
|
|
|
|
| 591 |
batch_queue: Batch queue containing historical generation data
|
| 592 |
"""
|
| 593 |
if current_batch_index not in batch_queue:
|
| 594 |
+
return gr.skip(), gr.skip(), batch_queue
|
| 595 |
|
| 596 |
batch_data = batch_queue[current_batch_index]
|
| 597 |
params = batch_data.get("generation_params", {})
|
|
|
|
| 638 |
batch_queue[current_batch_index]["scores"] = [""] * 8
|
| 639 |
batch_queue[current_batch_index]["scores"][sample_idx - 1] = score_display
|
| 640 |
|
| 641 |
+
# Return: score_display (content + visible), accordion visible, batch_queue
|
| 642 |
+
return (
|
| 643 |
+
gr.update(value=score_display, visible=True), # score_display with content
|
| 644 |
+
gr.update(visible=True), # details_accordion
|
| 645 |
+
batch_queue
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
|
| 650 |
+
"""
|
| 651 |
+
Generate LRC timestamps for a specific audio sample.
|
| 652 |
+
|
| 653 |
+
This function retrieves cached generation data from batch_queue and calls
|
| 654 |
+
the handler's get_lyric_timestamp method to generate LRC format lyrics.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
dit_handler: DiT handler instance with get_lyric_timestamp method
|
| 658 |
+
sample_idx: Which sample to generate LRC for (1-8)
|
| 659 |
+
current_batch_index: Current batch index in batch_queue
|
| 660 |
+
batch_queue: Dictionary storing all batch generation data
|
| 661 |
+
vocal_language: Language code for lyrics
|
| 662 |
+
inference_steps: Number of inference steps used in generation
|
| 663 |
+
|
| 664 |
+
Returns:
|
| 665 |
+
LRC formatted string or error message
|
| 666 |
+
"""
|
| 667 |
+
import torch
|
| 668 |
+
|
| 669 |
+
if current_batch_index not in batch_queue:
|
| 670 |
+
return gr.skip(), gr.skip()
|
| 671 |
+
|
| 672 |
+
batch_data = batch_queue[current_batch_index]
|
| 673 |
+
extra_outputs = batch_data.get("extra_outputs", {})
|
| 674 |
+
|
| 675 |
+
# Check if required data is available
|
| 676 |
+
if not extra_outputs:
|
| 677 |
+
return gr.update(value=t("messages.lrc_no_extra_outputs"), visible=True), gr.update(visible=True)
|
| 678 |
+
|
| 679 |
+
pred_latents = extra_outputs.get("pred_latents")
|
| 680 |
+
encoder_hidden_states = extra_outputs.get("encoder_hidden_states")
|
| 681 |
+
encoder_attention_mask = extra_outputs.get("encoder_attention_mask")
|
| 682 |
+
context_latents = extra_outputs.get("context_latents")
|
| 683 |
+
lyric_token_idss = extra_outputs.get("lyric_token_idss")
|
| 684 |
+
|
| 685 |
+
if any(x is None for x in [pred_latents, encoder_hidden_states, encoder_attention_mask, context_latents, lyric_token_idss]):
|
| 686 |
+
return gr.update(value=t("messages.lrc_missing_tensors"), visible=True), gr.update(visible=True)
|
| 687 |
+
|
| 688 |
+
# Adjust sample_idx to 0-based
|
| 689 |
+
sample_idx_0based = sample_idx - 1
|
| 690 |
+
|
| 691 |
+
# Check if sample exists in batch
|
| 692 |
+
batch_size = pred_latents.shape[0]
|
| 693 |
+
if sample_idx_0based >= batch_size:
|
| 694 |
+
return gr.update(value=t("messages.lrc_sample_not_exist"), visible=True), gr.update(visible=True)
|
| 695 |
+
|
| 696 |
+
# Extract the specific sample's data
|
| 697 |
+
try:
|
| 698 |
+
# Get audio duration from batch data
|
| 699 |
+
params = batch_data.get("generation_params", {})
|
| 700 |
+
audio_duration = params.get("audio_duration", -1)
|
| 701 |
+
|
| 702 |
+
# Calculate duration from latents if not specified
|
| 703 |
+
if audio_duration is None or audio_duration <= 0:
|
| 704 |
+
# latent_length * frames_per_second_ratio ≈ audio_duration
|
| 705 |
+
# Assuming 25 Hz latent rate: latent_length / 25 = duration
|
| 706 |
+
latent_length = pred_latents.shape[1]
|
| 707 |
+
audio_duration = latent_length / 25.0 # 25 Hz latent rate
|
| 708 |
+
|
| 709 |
+
# Get the sample's data (keep batch dimension for handler)
|
| 710 |
+
sample_pred_latent = pred_latents[sample_idx_0based:sample_idx_0based+1]
|
| 711 |
+
sample_encoder_hidden_states = encoder_hidden_states[sample_idx_0based:sample_idx_0based+1]
|
| 712 |
+
sample_encoder_attention_mask = encoder_attention_mask[sample_idx_0based:sample_idx_0based+1]
|
| 713 |
+
sample_context_latents = context_latents[sample_idx_0based:sample_idx_0based+1]
|
| 714 |
+
sample_lyric_token_ids = lyric_token_idss[sample_idx_0based:sample_idx_0based+1]
|
| 715 |
+
|
| 716 |
+
# Call handler to generate timestamps
|
| 717 |
+
result = dit_handler.get_lyric_timestamp(
|
| 718 |
+
pred_latent=sample_pred_latent,
|
| 719 |
+
encoder_hidden_states=sample_encoder_hidden_states,
|
| 720 |
+
encoder_attention_mask=sample_encoder_attention_mask,
|
| 721 |
+
context_latents=sample_context_latents,
|
| 722 |
+
lyric_token_ids=sample_lyric_token_ids,
|
| 723 |
+
total_duration_seconds=float(audio_duration),
|
| 724 |
+
vocal_language=vocal_language or "en",
|
| 725 |
+
inference_steps=int(inference_steps),
|
| 726 |
+
seed=42, # Use fixed seed for reproducibility
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
if result.get("success"):
|
| 730 |
+
lrc_text = result.get("lrc_text", "")
|
| 731 |
+
if not lrc_text:
|
| 732 |
+
return gr.update(value=t("messages.lrc_empty_result"), visible=True), gr.update(visible=True)
|
| 733 |
+
return gr.update(value=lrc_text, visible=True), gr.update(visible=True)
|
| 734 |
+
else:
|
| 735 |
+
error_msg = result.get("error", "Unknown error")
|
| 736 |
+
return gr.update(value=f"❌ {error_msg}", visible=True), gr.update(visible=True)
|
| 737 |
+
|
| 738 |
+
except Exception as e:
|
| 739 |
+
logger.exception("[generate_lrc_handler] Error generating LRC")
|
| 740 |
+
return gr.update(value=f"❌ Error: {str(e)}", visible=True), gr.update(visible=True)
|
| 741 |
|
| 742 |
|
| 743 |
def capture_current_params(
|
|
|
|
| 853 |
final_result_from_inner = partial_result
|
| 854 |
# current_batch_index, total_batches, batch_queue, next_params,
|
| 855 |
# batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
|
| 856 |
+
# Slice off extra_outputs (last item) before re-yielding to UI
|
| 857 |
+
ui_result = partial_result[:-1] if len(partial_result) > 31 else partial_result
|
| 858 |
+
yield ui_result + (
|
| 859 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
| 860 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 861 |
)
|
|
|
|
| 863 |
all_audio_paths = result[8]
|
| 864 |
|
| 865 |
if all_audio_paths is None:
|
| 866 |
+
# Slice off extra_outputs before yielding to UI
|
| 867 |
+
ui_result = result[:-1] if len(result) > 31 else result
|
| 868 |
+
yield ui_result + (
|
| 869 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(),
|
| 870 |
gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
|
| 871 |
)
|
| 872 |
return
|
| 873 |
|
| 874 |
# Extract results from generation (使用 result 下标访问)
|
| 875 |
+
# New indices after removing 6 align_* items (was 12-17, now shifted down by 6)
|
| 876 |
generation_info = result[9]
|
| 877 |
seed_value_for_ui = result[11]
|
| 878 |
+
lm_generated_metadata = result[29] # was 35, now 29
|
| 879 |
|
| 880 |
# Extract codes
|
| 881 |
+
generated_codes_single = result[20] # was 26, now 20
|
| 882 |
+
generated_codes_batch = [result[21], result[22], result[23], result[24], result[25], result[26], result[27], result[28]] # was 27-34, now 21-28
|
| 883 |
|
| 884 |
# Determine which codes to store based on mode
|
| 885 |
if allow_lm_batch and batch_size_input >= 2:
|
|
|
|
| 938 |
next_params["text2music_audio_code_string"] = ""
|
| 939 |
next_params["random_seed_checkbox"] = True
|
| 940 |
|
| 941 |
+
# Extract extra_outputs from result tuple (index 31)
|
| 942 |
+
extra_outputs_from_result = result[31] if len(result) > 31 else {}
|
| 943 |
+
|
| 944 |
# Store current batch in queue
|
| 945 |
batch_queue = store_batch_in_queue(
|
| 946 |
batch_queue,
|
|
|
|
| 953 |
batch_size=int(batch_size_input),
|
| 954 |
generation_params=saved_params,
|
| 955 |
lm_generated_metadata=lm_generated_metadata,
|
| 956 |
+
extra_outputs=extra_outputs_from_result, # Store extra outputs for LRC generation
|
| 957 |
status="completed"
|
| 958 |
)
|
| 959 |
|
|
|
|
| 973 |
|
| 974 |
# 4. Yield final result (includes Batch UI updates)
|
| 975 |
# The result here is already a tuple structure
|
| 976 |
+
# Slice off extra_outputs (last item) before yielding to UI - it's already stored in batch_queue
|
| 977 |
+
ui_result = result[:-1] if len(result) > 31 else result
|
| 978 |
+
yield ui_result + (
|
| 979 |
current_batch_index,
|
| 980 |
total_batches,
|
| 981 |
batch_queue,
|
|
|
|
| 1145 |
final_result = partial_result
|
| 1146 |
|
| 1147 |
# Extract results from final_result
|
| 1148 |
+
# Indices shifted by -6 after removing align_* items
|
| 1149 |
all_audio_paths = final_result[8] # generated_audio_batch
|
| 1150 |
generation_info = final_result[9]
|
| 1151 |
seed_value_for_ui = final_result[11]
|
| 1152 |
+
lm_generated_metadata = final_result[29] # was 35, now 29
|
| 1153 |
|
| 1154 |
# Extract codes
|
| 1155 |
+
generated_codes_single = final_result[20] # was 26, now 20
|
| 1156 |
+
generated_codes_batch = [final_result[21], final_result[22], final_result[23], final_result[24], final_result[25], final_result[26], final_result[27], final_result[28]] # was 27-34, now 21-28
|
| 1157 |
|
| 1158 |
# Determine which codes to store
|
| 1159 |
batch_size = params.get("batch_size_input", 2)
|
|
|
|
| 1176 |
logger.info(f" - codes_to_store: STRING with {len(codes_to_store) if codes_to_store else 0} chars")
|
| 1177 |
|
| 1178 |
# Store next batch in queue with codes, batch settings, and ALL generation params
|
| 1179 |
+
# Note: extra_outputs not available for background batches (LRC not supported for auto-gen batches)
|
| 1180 |
batch_queue = store_batch_in_queue(
|
| 1181 |
batch_queue,
|
| 1182 |
next_batch_idx,
|
|
|
|
| 1188 |
batch_size=int(batch_size),
|
| 1189 |
generation_params=params,
|
| 1190 |
lm_generated_metadata=lm_generated_metadata,
|
| 1191 |
+
extra_outputs=None, # Not available for background batches
|
| 1192 |
status="completed"
|
| 1193 |
)
|
| 1194 |
|
acestep/gradio_ui/i18n/en.json
CHANGED
|
@@ -148,8 +148,6 @@
|
|
| 148 |
"cover_strength_info": "Control how many denoising steps use cover mode",
|
| 149 |
"score_sensitivity_label": "Quality Score Sensitivity",
|
| 150 |
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
| 151 |
-
"attention_focus_label": "Output Attention Focus Score (disabled)",
|
| 152 |
-
"attention_focus_info": "Output attention focus score analysis",
|
| 153 |
"think_label": "Think",
|
| 154 |
"parallel_thinking_label": "ParallelThinking",
|
| 155 |
"generate_btn": "🎵 Generate Music",
|
|
@@ -162,8 +160,12 @@
|
|
| 162 |
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 163 |
"save_btn": "💾 Save",
|
| 164 |
"score_btn": "📊 Score",
|
|
|
|
| 165 |
"quality_score_label": "Quality Score (Sample {n})",
|
| 166 |
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
|
|
|
|
|
|
|
|
|
| 167 |
"generation_status": "Generation Status",
|
| 168 |
"current_batch": "Current Batch",
|
| 169 |
"batch_indicator": "Batch {current} / {total}",
|
|
@@ -173,11 +175,7 @@
|
|
| 173 |
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
| 174 |
"batch_results_title": "📁 Batch Results & Generation Details",
|
| 175 |
"all_files_label": "📁 All Generated Files (Download)",
|
| 176 |
-
"generation_details": "Generation Details"
|
| 177 |
-
"attention_analysis": "⚖️ Attention Focus Score Analysis",
|
| 178 |
-
"attention_score": "Attention Focus Score (Sample {n})",
|
| 179 |
-
"lyric_timestamps": "Lyric Timestamps (Sample {n})",
|
| 180 |
-
"attention_heatmap": "Attention Focus Score Heatmap (Sample {n})"
|
| 181 |
},
|
| 182 |
"messages": {
|
| 183 |
"no_audio_to_save": "❌ No audio to save",
|
|
@@ -206,6 +204,11 @@
|
|
| 206 |
"scoring_failed": "❌ Error: Batch data not found",
|
| 207 |
"no_codes": "❌ No audio codes available. Please generate music first.",
|
| 208 |
"score_failed": "❌ Scoring failed: {error}",
|
| 209 |
-
"score_error": "❌ Error calculating score: {error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
}
|
|
|
|
| 148 |
"cover_strength_info": "Control how many denoising steps use cover mode",
|
| 149 |
"score_sensitivity_label": "Quality Score Sensitivity",
|
| 150 |
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
|
|
|
|
|
|
| 151 |
"think_label": "Think",
|
| 152 |
"parallel_thinking_label": "ParallelThinking",
|
| 153 |
"generate_btn": "🎵 Generate Music",
|
|
|
|
| 160 |
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 161 |
"save_btn": "💾 Save",
|
| 162 |
"score_btn": "📊 Score",
|
| 163 |
+
"lrc_btn": "🎵 LRC",
|
| 164 |
"quality_score_label": "Quality Score (Sample {n})",
|
| 165 |
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
| 166 |
+
"lrc_label": "Lyrics Timestamps (Sample {n})",
|
| 167 |
+
"lrc_placeholder": "Click 'LRC' to generate timestamps",
|
| 168 |
+
"details_accordion": "📊 Score & LRC",
|
| 169 |
"generation_status": "Generation Status",
|
| 170 |
"current_batch": "Current Batch",
|
| 171 |
"batch_indicator": "Batch {current} / {total}",
|
|
|
|
| 175 |
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
| 176 |
"batch_results_title": "📁 Batch Results & Generation Details",
|
| 177 |
"all_files_label": "📁 All Generated Files (Download)",
|
| 178 |
+
"generation_details": "Generation Details"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
},
|
| 180 |
"messages": {
|
| 181 |
"no_audio_to_save": "❌ No audio to save",
|
|
|
|
| 204 |
"scoring_failed": "❌ Error: Batch data not found",
|
| 205 |
"no_codes": "❌ No audio codes available. Please generate music first.",
|
| 206 |
"score_failed": "❌ Scoring failed: {error}",
|
| 207 |
+
"score_error": "❌ Error calculating score: {error}",
|
| 208 |
+
"lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
|
| 209 |
+
"lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
|
| 210 |
+
"lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
|
| 211 |
+
"lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
|
| 212 |
+
"lrc_empty_result": "⚠️ LRC generation produced empty result."
|
| 213 |
}
|
| 214 |
}
|
acestep/gradio_ui/i18n/ja.json
CHANGED
|
@@ -148,8 +148,6 @@
|
|
| 148 |
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
| 149 |
"score_sensitivity_label": "品質スコア感度",
|
| 150 |
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
| 151 |
-
"attention_focus_label": "注意焦点スコアを出力(無効)",
|
| 152 |
-
"attention_focus_info": "注意焦点スコア分析を出力",
|
| 153 |
"think_label": "思考",
|
| 154 |
"parallel_thinking_label": "並列思考",
|
| 155 |
"generate_btn": "🎵 音楽を生成",
|
|
@@ -162,8 +160,12 @@
|
|
| 162 |
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 163 |
"save_btn": "💾 保存",
|
| 164 |
"score_btn": "📊 スコア",
|
|
|
|
| 165 |
"quality_score_label": "品質スコア(サンプル {n})",
|
| 166 |
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
|
|
|
|
|
|
|
|
|
| 167 |
"generation_status": "生成ステータス",
|
| 168 |
"current_batch": "現在のバッチ",
|
| 169 |
"batch_indicator": "バッチ {current} / {total}",
|
|
@@ -173,11 +175,7 @@
|
|
| 173 |
"restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
|
| 174 |
"batch_results_title": "📁 バッチ結果と生成詳細",
|
| 175 |
"all_files_label": "📁 すべての生成ファイル(ダウンロード)",
|
| 176 |
-
"generation_details": "生成詳細"
|
| 177 |
-
"attention_analysis": "⚖️ 注意焦点スコア分析",
|
| 178 |
-
"attention_score": "注意焦点スコア(サンプル {n})",
|
| 179 |
-
"lyric_timestamps": "歌詞タイムスタンプ(サンプル {n})",
|
| 180 |
-
"attention_heatmap": "注意焦点スコアヒートマップ(サンプル {n})"
|
| 181 |
},
|
| 182 |
"messages": {
|
| 183 |
"no_audio_to_save": "❌ 保存するオーディオがありません",
|
|
@@ -206,6 +204,11 @@
|
|
| 206 |
"scoring_failed": "❌ エラー: バッチデータが見つかりません",
|
| 207 |
"no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
|
| 208 |
"score_failed": "❌ スコアリングに失敗しました: {error}",
|
| 209 |
-
"score_error": "❌ スコア計算エラー: {error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
}
|
|
|
|
| 148 |
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
| 149 |
"score_sensitivity_label": "品質スコア感度",
|
| 150 |
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
|
|
|
|
|
|
| 151 |
"think_label": "思考",
|
| 152 |
"parallel_thinking_label": "並列思考",
|
| 153 |
"generate_btn": "🎵 音楽を生成",
|
|
|
|
| 160 |
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 161 |
"save_btn": "💾 保存",
|
| 162 |
"score_btn": "📊 スコア",
|
| 163 |
+
"lrc_btn": "🎵 LRC",
|
| 164 |
"quality_score_label": "品質スコア(サンプル {n})",
|
| 165 |
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
| 166 |
+
"lrc_label": "歌詞タイムスタンプ(サンプル {n})",
|
| 167 |
+
"lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
|
| 168 |
+
"details_accordion": "📊 スコア & LRC",
|
| 169 |
"generation_status": "生成ステータス",
|
| 170 |
"current_batch": "現在のバッチ",
|
| 171 |
"batch_indicator": "バッチ {current} / {total}",
|
|
|
|
| 175 |
"restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
|
| 176 |
"batch_results_title": "📁 バッチ結果と生成詳細",
|
| 177 |
"all_files_label": "📁 すべての生成ファイル(ダウンロード)",
|
| 178 |
+
"generation_details": "生成詳細"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
},
|
| 180 |
"messages": {
|
| 181 |
"no_audio_to_save": "❌ 保存するオーディオがありません",
|
|
|
|
| 204 |
"scoring_failed": "❌ エラー: バッチデータが見つかりません",
|
| 205 |
"no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
|
| 206 |
"score_failed": "❌ スコアリングに失敗しました: {error}",
|
| 207 |
+
"score_error": "❌ スコア計算エラー: {error}",
|
| 208 |
+
"lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
|
| 209 |
+
"lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
|
| 210 |
+
"lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
|
| 211 |
+
"lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
|
| 212 |
+
"lrc_empty_result": "⚠️ LRC生成の結果が空です。"
|
| 213 |
}
|
| 214 |
}
|
acestep/gradio_ui/i18n/zh.json
CHANGED
|
@@ -148,8 +148,6 @@
|
|
| 148 |
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
| 149 |
"score_sensitivity_label": "质量评分敏感度",
|
| 150 |
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
| 151 |
-
"attention_focus_label": "输出注意力焦点分数(已禁用)",
|
| 152 |
-
"attention_focus_info": "输出注意力焦点分数分析",
|
| 153 |
"think_label": "思考",
|
| 154 |
"parallel_thinking_label": "并行思考",
|
| 155 |
"generate_btn": "🎵 生成音乐",
|
|
@@ -162,8 +160,12 @@
|
|
| 162 |
"send_to_src_btn": "🔗 发送到源音频",
|
| 163 |
"save_btn": "💾 保存",
|
| 164 |
"score_btn": "📊 评分",
|
|
|
|
| 165 |
"quality_score_label": "质量分数(样本 {n})",
|
| 166 |
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
|
|
|
|
|
|
|
|
|
| 167 |
"generation_status": "生成状态",
|
| 168 |
"current_batch": "当前批次",
|
| 169 |
"batch_indicator": "批次 {current} / {total}",
|
|
@@ -173,11 +175,7 @@
|
|
| 173 |
"restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
|
| 174 |
"batch_results_title": "📁 批量结果和生成详情",
|
| 175 |
"all_files_label": "📁 所有生成的文件(下载)",
|
| 176 |
-
"generation_details": "生成详情"
|
| 177 |
-
"attention_analysis": "⚖️ 注意力焦点分数分析",
|
| 178 |
-
"attention_score": "注意力焦点分数(样本 {n})",
|
| 179 |
-
"lyric_timestamps": "歌词时间戳(样本 {n})",
|
| 180 |
-
"attention_heatmap": "注意力焦点分数热图(样本 {n})"
|
| 181 |
},
|
| 182 |
"messages": {
|
| 183 |
"no_audio_to_save": "❌ 没有要保存的音频",
|
|
@@ -206,6 +204,11 @@
|
|
| 206 |
"scoring_failed": "❌ 错误: 未找到批次数据",
|
| 207 |
"no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
|
| 208 |
"score_failed": "❌ 评分失败: {error}",
|
| 209 |
-
"score_error": "❌ 计算分数时出错: {error}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
}
|
| 211 |
}
|
|
|
|
| 148 |
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
| 149 |
"score_sensitivity_label": "质量评分敏感度",
|
| 150 |
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
|
|
|
|
|
|
| 151 |
"think_label": "思考",
|
| 152 |
"parallel_thinking_label": "并行思考",
|
| 153 |
"generate_btn": "🎵 生成音乐",
|
|
|
|
| 160 |
"send_to_src_btn": "🔗 发送到源音频",
|
| 161 |
"save_btn": "💾 保存",
|
| 162 |
"score_btn": "📊 评分",
|
| 163 |
+
"lrc_btn": "🎵 LRC",
|
| 164 |
"quality_score_label": "质量分数(样本 {n})",
|
| 165 |
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
| 166 |
+
"lrc_label": "歌词时间戳(样本 {n})",
|
| 167 |
+
"lrc_placeholder": "点击'LRC'生成时间戳",
|
| 168 |
+
"details_accordion": "📊 评分与LRC",
|
| 169 |
"generation_status": "生成状态",
|
| 170 |
"current_batch": "当前批次",
|
| 171 |
"batch_indicator": "批次 {current} / {total}",
|
|
|
|
| 175 |
"restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
|
| 176 |
"batch_results_title": "📁 批量结果和生成详情",
|
| 177 |
"all_files_label": "📁 所有生成的文件(下载)",
|
| 178 |
+
"generation_details": "生成详情"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
},
|
| 180 |
"messages": {
|
| 181 |
"no_audio_to_save": "❌ 没有要保存的音频",
|
|
|
|
| 204 |
"scoring_failed": "❌ 错误: 未找到批次数据",
|
| 205 |
"no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
|
| 206 |
"score_failed": "❌ 评分失败: {error}",
|
| 207 |
+
"score_error": "❌ 计算分数时出错: {error}",
|
| 208 |
+
"lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
|
| 209 |
+
"lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
|
| 210 |
+
"lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
|
| 211 |
+
"lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
|
| 212 |
+
"lrc_empty_result": "⚠️ LRC生成结果为空。"
|
| 213 |
}
|
| 214 |
}
|
acestep/gradio_ui/interfaces/result.py
CHANGED
|
@@ -50,11 +50,24 @@ def create_results_section(dit_handler) -> dict:
|
|
| 50 |
size="sm",
|
| 51 |
scale=1
|
| 52 |
)
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
with gr.Column(visible=True) as audio_col_2:
|
| 59 |
generated_audio_2 = gr.Audio(
|
| 60 |
label=t("results.generated_music", n=2),
|
|
@@ -81,11 +94,24 @@ def create_results_section(dit_handler) -> dict:
|
|
| 81 |
size="sm",
|
| 82 |
scale=1
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
with gr.Column(visible=False) as audio_col_3:
|
| 90 |
generated_audio_3 = gr.Audio(
|
| 91 |
label=t("results.generated_music", n=3),
|
|
@@ -112,11 +138,24 @@ def create_results_section(dit_handler) -> dict:
|
|
| 112 |
size="sm",
|
| 113 |
scale=1
|
| 114 |
)
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
with gr.Column(visible=False) as audio_col_4:
|
| 121 |
generated_audio_4 = gr.Audio(
|
| 122 |
label=t("results.generated_music", n=4),
|
|
@@ -143,11 +182,24 @@ def create_results_section(dit_handler) -> dict:
|
|
| 143 |
size="sm",
|
| 144 |
scale=1
|
| 145 |
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# Second row for batch size 5-8 (initially hidden)
|
| 153 |
with gr.Row(visible=False) as audio_row_5_8:
|
|
@@ -162,11 +214,19 @@ def create_results_section(dit_handler) -> dict:
|
|
| 162 |
send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 163 |
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 164 |
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
with gr.Column() as audio_col_6:
|
| 171 |
generated_audio_6 = gr.Audio(
|
| 172 |
label=t("results.generated_music", n=6),
|
|
@@ -178,11 +238,19 @@ def create_results_section(dit_handler) -> dict:
|
|
| 178 |
send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 179 |
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 180 |
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
with gr.Column() as audio_col_7:
|
| 187 |
generated_audio_7 = gr.Audio(
|
| 188 |
label=t("results.generated_music", n=7),
|
|
@@ -194,11 +262,19 @@ def create_results_section(dit_handler) -> dict:
|
|
| 194 |
send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 195 |
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 196 |
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
with gr.Column() as audio_col_8:
|
| 203 |
generated_audio_8 = gr.Audio(
|
| 204 |
label=t("results.generated_music", n=8),
|
|
@@ -210,11 +286,19 @@ def create_results_section(dit_handler) -> dict:
|
|
| 210 |
send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 211 |
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 212 |
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 220 |
|
|
@@ -262,17 +346,6 @@ def create_results_section(dit_handler) -> dict:
|
|
| 262 |
interactive=False
|
| 263 |
)
|
| 264 |
generation_info = gr.Markdown(label=t("results.generation_details"))
|
| 265 |
-
|
| 266 |
-
with gr.Accordion(t("results.attention_analysis"), open=False):
|
| 267 |
-
with gr.Row():
|
| 268 |
-
with gr.Column():
|
| 269 |
-
align_score_1 = gr.Textbox(label=t("results.attention_score", n=1), interactive=False)
|
| 270 |
-
align_text_1 = gr.Textbox(label=t("results.lyric_timestamps", n=1), interactive=False, lines=10)
|
| 271 |
-
align_plot_1 = gr.Plot(label=t("results.attention_heatmap", n=1))
|
| 272 |
-
with gr.Column():
|
| 273 |
-
align_score_2 = gr.Textbox(label=t("results.attention_score", n=2), interactive=False)
|
| 274 |
-
align_text_2 = gr.Textbox(label=t("results.lyric_timestamps", n=2), interactive=False, lines=10)
|
| 275 |
-
align_plot_2 = gr.Plot(label=t("results.attention_heatmap", n=2))
|
| 276 |
|
| 277 |
return {
|
| 278 |
"lm_metadata_state": lm_metadata_state,
|
|
@@ -337,13 +410,31 @@ def create_results_section(dit_handler) -> dict:
|
|
| 337 |
"score_display_6": score_display_6,
|
| 338 |
"score_display_7": score_display_7,
|
| 339 |
"score_display_8": score_display_8,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
"generated_audio_batch": generated_audio_batch,
|
| 341 |
"generation_info": generation_info,
|
| 342 |
-
"align_score_1": align_score_1,
|
| 343 |
-
"align_text_1": align_text_1,
|
| 344 |
-
"align_plot_1": align_plot_1,
|
| 345 |
-
"align_score_2": align_score_2,
|
| 346 |
-
"align_text_2": align_text_2,
|
| 347 |
-
"align_plot_2": align_plot_2,
|
| 348 |
}
|
| 349 |
|
|
|
|
| 50 |
size="sm",
|
| 51 |
scale=1
|
| 52 |
)
|
| 53 |
+
lrc_btn_1 = gr.Button(
|
| 54 |
+
t("results.lrc_btn"),
|
| 55 |
+
variant="secondary",
|
| 56 |
+
size="sm",
|
| 57 |
+
scale=1
|
| 58 |
+
)
|
| 59 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_1:
|
| 60 |
+
score_display_1 = gr.Textbox(
|
| 61 |
+
label=t("results.quality_score_label", n=1),
|
| 62 |
+
interactive=False,
|
| 63 |
+
visible=False
|
| 64 |
+
)
|
| 65 |
+
lrc_display_1 = gr.Textbox(
|
| 66 |
+
label=t("results.lrc_label", n=1),
|
| 67 |
+
interactive=False,
|
| 68 |
+
lines=8,
|
| 69 |
+
visible=False
|
| 70 |
+
)
|
| 71 |
with gr.Column(visible=True) as audio_col_2:
|
| 72 |
generated_audio_2 = gr.Audio(
|
| 73 |
label=t("results.generated_music", n=2),
|
|
|
|
| 94 |
size="sm",
|
| 95 |
scale=1
|
| 96 |
)
|
| 97 |
+
lrc_btn_2 = gr.Button(
|
| 98 |
+
t("results.lrc_btn"),
|
| 99 |
+
variant="secondary",
|
| 100 |
+
size="sm",
|
| 101 |
+
scale=1
|
| 102 |
+
)
|
| 103 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_2:
|
| 104 |
+
score_display_2 = gr.Textbox(
|
| 105 |
+
label=t("results.quality_score_label", n=2),
|
| 106 |
+
interactive=False,
|
| 107 |
+
visible=False
|
| 108 |
+
)
|
| 109 |
+
lrc_display_2 = gr.Textbox(
|
| 110 |
+
label=t("results.lrc_label", n=2),
|
| 111 |
+
interactive=False,
|
| 112 |
+
lines=8,
|
| 113 |
+
visible=False
|
| 114 |
+
)
|
| 115 |
with gr.Column(visible=False) as audio_col_3:
|
| 116 |
generated_audio_3 = gr.Audio(
|
| 117 |
label=t("results.generated_music", n=3),
|
|
|
|
| 138 |
size="sm",
|
| 139 |
scale=1
|
| 140 |
)
|
| 141 |
+
lrc_btn_3 = gr.Button(
|
| 142 |
+
t("results.lrc_btn"),
|
| 143 |
+
variant="secondary",
|
| 144 |
+
size="sm",
|
| 145 |
+
scale=1
|
| 146 |
+
)
|
| 147 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_3:
|
| 148 |
+
score_display_3 = gr.Textbox(
|
| 149 |
+
label=t("results.quality_score_label", n=3),
|
| 150 |
+
interactive=False,
|
| 151 |
+
visible=False
|
| 152 |
+
)
|
| 153 |
+
lrc_display_3 = gr.Textbox(
|
| 154 |
+
label=t("results.lrc_label", n=3),
|
| 155 |
+
interactive=False,
|
| 156 |
+
lines=8,
|
| 157 |
+
visible=False
|
| 158 |
+
)
|
| 159 |
with gr.Column(visible=False) as audio_col_4:
|
| 160 |
generated_audio_4 = gr.Audio(
|
| 161 |
label=t("results.generated_music", n=4),
|
|
|
|
| 182 |
size="sm",
|
| 183 |
scale=1
|
| 184 |
)
|
| 185 |
+
lrc_btn_4 = gr.Button(
|
| 186 |
+
t("results.lrc_btn"),
|
| 187 |
+
variant="secondary",
|
| 188 |
+
size="sm",
|
| 189 |
+
scale=1
|
| 190 |
+
)
|
| 191 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_4:
|
| 192 |
+
score_display_4 = gr.Textbox(
|
| 193 |
+
label=t("results.quality_score_label", n=4),
|
| 194 |
+
interactive=False,
|
| 195 |
+
visible=False
|
| 196 |
+
)
|
| 197 |
+
lrc_display_4 = gr.Textbox(
|
| 198 |
+
label=t("results.lrc_label", n=4),
|
| 199 |
+
interactive=False,
|
| 200 |
+
lines=8,
|
| 201 |
+
visible=False
|
| 202 |
+
)
|
| 203 |
|
| 204 |
# Second row for batch size 5-8 (initially hidden)
|
| 205 |
with gr.Row(visible=False) as audio_row_5_8:
|
|
|
|
| 214 |
send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 215 |
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 216 |
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 217 |
+
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 218 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_5:
|
| 219 |
+
score_display_5 = gr.Textbox(
|
| 220 |
+
label=t("results.quality_score_label", n=5),
|
| 221 |
+
interactive=False,
|
| 222 |
+
visible=False
|
| 223 |
+
)
|
| 224 |
+
lrc_display_5 = gr.Textbox(
|
| 225 |
+
label=t("results.lrc_label", n=5),
|
| 226 |
+
interactive=False,
|
| 227 |
+
lines=8,
|
| 228 |
+
visible=False
|
| 229 |
+
)
|
| 230 |
with gr.Column() as audio_col_6:
|
| 231 |
generated_audio_6 = gr.Audio(
|
| 232 |
label=t("results.generated_music", n=6),
|
|
|
|
| 238 |
send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 239 |
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 240 |
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 241 |
+
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 242 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_6:
|
| 243 |
+
score_display_6 = gr.Textbox(
|
| 244 |
+
label=t("results.quality_score_label", n=6),
|
| 245 |
+
interactive=False,
|
| 246 |
+
visible=False
|
| 247 |
+
)
|
| 248 |
+
lrc_display_6 = gr.Textbox(
|
| 249 |
+
label=t("results.lrc_label", n=6),
|
| 250 |
+
interactive=False,
|
| 251 |
+
lines=8,
|
| 252 |
+
visible=False
|
| 253 |
+
)
|
| 254 |
with gr.Column() as audio_col_7:
|
| 255 |
generated_audio_7 = gr.Audio(
|
| 256 |
label=t("results.generated_music", n=7),
|
|
|
|
| 262 |
send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 263 |
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 264 |
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 265 |
+
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 266 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_7:
|
| 267 |
+
score_display_7 = gr.Textbox(
|
| 268 |
+
label=t("results.quality_score_label", n=7),
|
| 269 |
+
interactive=False,
|
| 270 |
+
visible=False
|
| 271 |
+
)
|
| 272 |
+
lrc_display_7 = gr.Textbox(
|
| 273 |
+
label=t("results.lrc_label", n=7),
|
| 274 |
+
interactive=False,
|
| 275 |
+
lines=8,
|
| 276 |
+
visible=False
|
| 277 |
+
)
|
| 278 |
with gr.Column() as audio_col_8:
|
| 279 |
generated_audio_8 = gr.Audio(
|
| 280 |
label=t("results.generated_music", n=8),
|
|
|
|
| 286 |
send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
|
| 287 |
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 288 |
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
|
| 289 |
+
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
|
| 290 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_8:
|
| 291 |
+
score_display_8 = gr.Textbox(
|
| 292 |
+
label=t("results.quality_score_label", n=8),
|
| 293 |
+
interactive=False,
|
| 294 |
+
visible=False
|
| 295 |
+
)
|
| 296 |
+
lrc_display_8 = gr.Textbox(
|
| 297 |
+
label=t("results.lrc_label", n=8),
|
| 298 |
+
interactive=False,
|
| 299 |
+
lines=8,
|
| 300 |
+
visible=False
|
| 301 |
+
)
|
| 302 |
|
| 303 |
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 304 |
|
|
|
|
| 346 |
interactive=False
|
| 347 |
)
|
| 348 |
generation_info = gr.Markdown(label=t("results.generation_details"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
return {
|
| 351 |
"lm_metadata_state": lm_metadata_state,
|
|
|
|
| 410 |
"score_display_6": score_display_6,
|
| 411 |
"score_display_7": score_display_7,
|
| 412 |
"score_display_8": score_display_8,
|
| 413 |
+
"lrc_btn_1": lrc_btn_1,
|
| 414 |
+
"lrc_btn_2": lrc_btn_2,
|
| 415 |
+
"lrc_btn_3": lrc_btn_3,
|
| 416 |
+
"lrc_btn_4": lrc_btn_4,
|
| 417 |
+
"lrc_btn_5": lrc_btn_5,
|
| 418 |
+
"lrc_btn_6": lrc_btn_6,
|
| 419 |
+
"lrc_btn_7": lrc_btn_7,
|
| 420 |
+
"lrc_btn_8": lrc_btn_8,
|
| 421 |
+
"lrc_display_1": lrc_display_1,
|
| 422 |
+
"lrc_display_2": lrc_display_2,
|
| 423 |
+
"lrc_display_3": lrc_display_3,
|
| 424 |
+
"lrc_display_4": lrc_display_4,
|
| 425 |
+
"lrc_display_5": lrc_display_5,
|
| 426 |
+
"lrc_display_6": lrc_display_6,
|
| 427 |
+
"lrc_display_7": lrc_display_7,
|
| 428 |
+
"lrc_display_8": lrc_display_8,
|
| 429 |
+
"details_accordion_1": details_accordion_1,
|
| 430 |
+
"details_accordion_2": details_accordion_2,
|
| 431 |
+
"details_accordion_3": details_accordion_3,
|
| 432 |
+
"details_accordion_4": details_accordion_4,
|
| 433 |
+
"details_accordion_5": details_accordion_5,
|
| 434 |
+
"details_accordion_6": details_accordion_6,
|
| 435 |
+
"details_accordion_7": details_accordion_7,
|
| 436 |
+
"details_accordion_8": details_accordion_8,
|
| 437 |
"generated_audio_batch": generated_audio_batch,
|
| 438 |
"generation_info": generation_info,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
}
|
| 440 |
|
acestep/handler.py
CHANGED
|
@@ -31,6 +31,7 @@ from acestep.constants import (
|
|
| 31 |
SFT_GEN_PROMPT,
|
| 32 |
DEFAULT_DIT_INSTRUCTION,
|
| 33 |
)
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
warnings.filterwarnings("ignore")
|
|
@@ -65,13 +66,7 @@ class AceStepHandler:
|
|
| 65 |
self.batch_size = 2
|
| 66 |
|
| 67 |
# Custom layers config
|
| 68 |
-
self.custom_layers_config = {
|
| 69 |
-
2: [6, 7],
|
| 70 |
-
3: [10, 11],
|
| 71 |
-
4: [3],
|
| 72 |
-
5: [8, 9, 11],
|
| 73 |
-
6: [8]
|
| 74 |
-
}
|
| 75 |
self.offload_to_cpu = False
|
| 76 |
self.offload_dit_to_cpu = False
|
| 77 |
self.current_offload_cost = 0.0
|
|
@@ -1953,6 +1948,23 @@ class AceStepHandler:
|
|
| 1953 |
}
|
| 1954 |
logger.info("[service_generate] Generating audio...")
|
| 1955 |
with self._load_model_context("model"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1956 |
outputs = self.model.generate_audio(**generate_kwargs)
|
| 1957 |
|
| 1958 |
# Add intermediate information to outputs for extra_outputs
|
|
@@ -1962,6 +1974,12 @@ class AceStepHandler:
|
|
| 1962 |
outputs["spans"] = spans
|
| 1963 |
outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
|
| 1964 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1965 |
return outputs
|
| 1966 |
|
| 1967 |
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
|
@@ -2268,16 +2286,27 @@ class AceStepHandler:
|
|
| 2268 |
spans = outputs.get("spans", []) # List of tuples
|
| 2269 |
latent_masks = outputs.get("latent_masks") # [batch, T]
|
| 2270 |
|
| 2271 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2272 |
extra_outputs = {
|
| 2273 |
-
"pred_latents": pred_latents.cpu() if pred_latents is not None else None,
|
| 2274 |
-
"target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
|
| 2275 |
-
"src_latents": src_latents.cpu() if src_latents is not None else None,
|
| 2276 |
-
"chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
|
| 2277 |
-
"latent_masks": latent_masks.cpu() if latent_masks is not None else None,
|
| 2278 |
"spans": spans,
|
| 2279 |
"time_costs": time_costs,
|
| 2280 |
"seed_value": seed_value_for_ui,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2281 |
}
|
| 2282 |
|
| 2283 |
# Build audios list with tensor data (no file paths, no UUIDs, handled outside)
|
|
@@ -2307,3 +2336,220 @@ class AceStepHandler:
|
|
| 2307 |
"success": False,
|
| 2308 |
"error": str(e),
|
| 2309 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
SFT_GEN_PROMPT,
|
| 32 |
DEFAULT_DIT_INSTRUCTION,
|
| 33 |
)
|
| 34 |
+
from acestep.dit_alignment_score import MusicStampsAligner
|
| 35 |
|
| 36 |
|
| 37 |
warnings.filterwarnings("ignore")
|
|
|
|
| 66 |
self.batch_size = 2
|
| 67 |
|
| 68 |
# Custom layers config
|
| 69 |
+
self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
self.offload_to_cpu = False
|
| 71 |
self.offload_dit_to_cpu = False
|
| 72 |
self.current_offload_cost = 0.0
|
|
|
|
| 1948 |
}
|
| 1949 |
logger.info("[service_generate] Generating audio...")
|
| 1950 |
with self._load_model_context("model"):
|
| 1951 |
+
# Prepare condition tensors first (for LRC timestamp generation)
|
| 1952 |
+
encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition(
|
| 1953 |
+
text_hidden_states=text_hidden_states,
|
| 1954 |
+
text_attention_mask=text_attention_mask,
|
| 1955 |
+
lyric_hidden_states=lyric_hidden_states,
|
| 1956 |
+
lyric_attention_mask=lyric_attention_mask,
|
| 1957 |
+
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
|
| 1958 |
+
refer_audio_order_mask=refer_audio_order_mask,
|
| 1959 |
+
hidden_states=src_latents,
|
| 1960 |
+
attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype),
|
| 1961 |
+
silence_latent=self.silence_latent,
|
| 1962 |
+
src_latents=src_latents,
|
| 1963 |
+
chunk_masks=chunk_mask,
|
| 1964 |
+
is_covers=is_covers,
|
| 1965 |
+
precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
|
| 1966 |
+
)
|
| 1967 |
+
|
| 1968 |
outputs = self.model.generate_audio(**generate_kwargs)
|
| 1969 |
|
| 1970 |
# Add intermediate information to outputs for extra_outputs
|
|
|
|
| 1974 |
outputs["spans"] = spans
|
| 1975 |
outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
|
| 1976 |
|
| 1977 |
+
# Add condition tensors for LRC timestamp generation
|
| 1978 |
+
outputs["encoder_hidden_states"] = encoder_hidden_states
|
| 1979 |
+
outputs["encoder_attention_mask"] = encoder_attention_mask
|
| 1980 |
+
outputs["context_latents"] = context_latents
|
| 1981 |
+
outputs["lyric_token_idss"] = lyric_token_idss
|
| 1982 |
+
|
| 1983 |
return outputs
|
| 1984 |
|
| 1985 |
def tiled_decode(self, latents, chunk_size=512, overlap=64):
|
|
|
|
| 2286 |
spans = outputs.get("spans", []) # List of tuples
|
| 2287 |
latent_masks = outputs.get("latent_masks") # [batch, T]
|
| 2288 |
|
| 2289 |
+
# Extract condition tensors for LRC timestamp generation
|
| 2290 |
+
encoder_hidden_states = outputs.get("encoder_hidden_states")
|
| 2291 |
+
encoder_attention_mask = outputs.get("encoder_attention_mask")
|
| 2292 |
+
context_latents = outputs.get("context_latents")
|
| 2293 |
+
lyric_token_idss = outputs.get("lyric_token_idss")
|
| 2294 |
+
|
| 2295 |
+
# Move all tensors to CPU to save VRAM (detach to release computation graph)
|
| 2296 |
extra_outputs = {
|
| 2297 |
+
"pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None,
|
| 2298 |
+
"target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None,
|
| 2299 |
+
"src_latents": src_latents.detach().cpu() if src_latents is not None else None,
|
| 2300 |
+
"chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None,
|
| 2301 |
+
"latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None,
|
| 2302 |
"spans": spans,
|
| 2303 |
"time_costs": time_costs,
|
| 2304 |
"seed_value": seed_value_for_ui,
|
| 2305 |
+
# Condition tensors for LRC timestamp generation
|
| 2306 |
+
"encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None,
|
| 2307 |
+
"encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None,
|
| 2308 |
+
"context_latents": context_latents.detach().cpu() if context_latents is not None else None,
|
| 2309 |
+
"lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None,
|
| 2310 |
}
|
| 2311 |
|
| 2312 |
# Build audios list with tensor data (no file paths, no UUIDs, handled outside)
|
|
|
|
| 2336 |
"success": False,
|
| 2337 |
"error": str(e),
|
| 2338 |
}
|
| 2339 |
+
|
| 2340 |
+
@torch.no_grad()
|
| 2341 |
+
def get_lyric_timestamp(
|
| 2342 |
+
self,
|
| 2343 |
+
pred_latent: torch.Tensor,
|
| 2344 |
+
encoder_hidden_states: torch.Tensor,
|
| 2345 |
+
encoder_attention_mask: torch.Tensor,
|
| 2346 |
+
context_latents: torch.Tensor,
|
| 2347 |
+
lyric_token_ids: torch.Tensor,
|
| 2348 |
+
total_duration_seconds: float,
|
| 2349 |
+
vocal_language: str = "en",
|
| 2350 |
+
inference_steps: int = 8,
|
| 2351 |
+
seed: int = 42,
|
| 2352 |
+
custom_layers_config: Optional[Dict] = None,
|
| 2353 |
+
) -> Dict[str, Any]:
|
| 2354 |
+
"""
|
| 2355 |
+
Generate lyrics timestamps from generated audio latents using cross-attention alignment.
|
| 2356 |
+
|
| 2357 |
+
This method adds noise to the final pred_latent and re-infers one step to get
|
| 2358 |
+
cross-attention matrices, then uses DTW to align lyrics tokens with audio frames.
|
| 2359 |
+
|
| 2360 |
+
Args:
|
| 2361 |
+
pred_latent: Generated latent tensor [batch, T, D]
|
| 2362 |
+
encoder_hidden_states: Cached encoder hidden states
|
| 2363 |
+
encoder_attention_mask: Cached encoder attention mask
|
| 2364 |
+
context_latents: Cached context latents
|
| 2365 |
+
lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
|
| 2366 |
+
total_duration_seconds: Total audio duration in seconds
|
| 2367 |
+
vocal_language: Language code for lyrics header parsing
|
| 2368 |
+
inference_steps: Number of inference steps (for noise level calculation)
|
| 2369 |
+
seed: Random seed for noise generation
|
| 2370 |
+
custom_layers_config: Dict mapping layer indices to head indices
|
| 2371 |
+
|
| 2372 |
+
Returns:
|
| 2373 |
+
Dict containing:
|
| 2374 |
+
- lrc_text: LRC formatted lyrics with timestamps
|
| 2375 |
+
- sentence_timestamps: List of SentenceTimestamp objects
|
| 2376 |
+
- token_timestamps: List of TokenTimestamp objects
|
| 2377 |
+
- success: Whether generation succeeded
|
| 2378 |
+
- error: Error message if failed
|
| 2379 |
+
"""
|
| 2380 |
+
from transformers.cache_utils import EncoderDecoderCache, DynamicCache
|
| 2381 |
+
|
| 2382 |
+
if self.model is None:
|
| 2383 |
+
return {
|
| 2384 |
+
"lrc_text": "",
|
| 2385 |
+
"sentence_timestamps": [],
|
| 2386 |
+
"token_timestamps": [],
|
| 2387 |
+
"success": False,
|
| 2388 |
+
"error": "Model not initialized"
|
| 2389 |
+
}
|
| 2390 |
+
|
| 2391 |
+
if custom_layers_config is None:
|
| 2392 |
+
custom_layers_config = self.custom_layers_config
|
| 2393 |
+
|
| 2394 |
+
try:
|
| 2395 |
+
# Move tensors to device
|
| 2396 |
+
device = self.device
|
| 2397 |
+
dtype = self.dtype
|
| 2398 |
+
|
| 2399 |
+
pred_latent = pred_latent.to(device=device, dtype=dtype)
|
| 2400 |
+
encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
|
| 2401 |
+
encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
|
| 2402 |
+
context_latents = context_latents.to(device=device, dtype=dtype)
|
| 2403 |
+
|
| 2404 |
+
bsz = pred_latent.shape[0]
|
| 2405 |
+
|
| 2406 |
+
# Calculate noise level: t_last = 1.0 / inference_steps
|
| 2407 |
+
t_last_val = 1.0 / inference_steps
|
| 2408 |
+
t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
|
| 2409 |
+
|
| 2410 |
+
x1 = pred_latent
|
| 2411 |
+
|
| 2412 |
+
# Generate noise
|
| 2413 |
+
if seed is None:
|
| 2414 |
+
x0 = torch.randn_like(x1)
|
| 2415 |
+
else:
|
| 2416 |
+
generator = torch.Generator(device=device).manual_seed(int(seed))
|
| 2417 |
+
x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype)
|
| 2418 |
+
|
| 2419 |
+
# Add noise to pred_latent: xt = t * noise + (1 - t) * x1
|
| 2420 |
+
xt = t_last_val * x0 + (1.0 - t_last_val) * x1
|
| 2421 |
+
|
| 2422 |
+
xt_in = xt
|
| 2423 |
+
t_in = t_curr_tensor
|
| 2424 |
+
|
| 2425 |
+
# Get null condition embedding
|
| 2426 |
+
encoder_hidden_states_in = encoder_hidden_states
|
| 2427 |
+
encoder_attention_mask_in = encoder_attention_mask
|
| 2428 |
+
context_latents_in = context_latents
|
| 2429 |
+
latent_length = x1.shape[1]
|
| 2430 |
+
attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype)
|
| 2431 |
+
attention_mask_in = attention_mask
|
| 2432 |
+
past_key_values = None
|
| 2433 |
+
|
| 2434 |
+
# Run decoder with output_attentions=True
|
| 2435 |
+
with self._load_model_context("model"):
|
| 2436 |
+
decoder = self.model.decoder
|
| 2437 |
+
decoder_outputs = decoder(
|
| 2438 |
+
hidden_states=xt_in,
|
| 2439 |
+
timestep=t_in,
|
| 2440 |
+
timestep_r=t_in,
|
| 2441 |
+
attention_mask=attention_mask_in,
|
| 2442 |
+
encoder_hidden_states=encoder_hidden_states_in,
|
| 2443 |
+
use_cache=False,
|
| 2444 |
+
past_key_values=past_key_values,
|
| 2445 |
+
encoder_attention_mask=encoder_attention_mask_in,
|
| 2446 |
+
context_latents=context_latents_in,
|
| 2447 |
+
output_attentions=True,
|
| 2448 |
+
custom_layers_config=custom_layers_config,
|
| 2449 |
+
enable_early_exit=True
|
| 2450 |
+
)
|
| 2451 |
+
|
| 2452 |
+
# Extract cross-attention matrices
|
| 2453 |
+
if decoder_outputs[2] is None:
|
| 2454 |
+
return {
|
| 2455 |
+
"lrc_text": "",
|
| 2456 |
+
"sentence_timestamps": [],
|
| 2457 |
+
"token_timestamps": [],
|
| 2458 |
+
"success": False,
|
| 2459 |
+
"error": "Model did not return attentions"
|
| 2460 |
+
}
|
| 2461 |
+
|
| 2462 |
+
cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None)
|
| 2463 |
+
|
| 2464 |
+
captured_layers_list = []
|
| 2465 |
+
for layer_attn in cross_attns:
|
| 2466 |
+
# Skip None values (layers that didn't return attention)
|
| 2467 |
+
if layer_attn is None:
|
| 2468 |
+
continue
|
| 2469 |
+
# Only take conditional part (first half of batch)
|
| 2470 |
+
cond_attn = layer_attn[:bsz]
|
| 2471 |
+
layer_matrix = cond_attn.transpose(-1, -2)
|
| 2472 |
+
captured_layers_list.append(layer_matrix)
|
| 2473 |
+
|
| 2474 |
+
if not captured_layers_list:
|
| 2475 |
+
return {
|
| 2476 |
+
"lrc_text": "",
|
| 2477 |
+
"sentence_timestamps": [],
|
| 2478 |
+
"token_timestamps": [],
|
| 2479 |
+
"success": False,
|
| 2480 |
+
"error": "No valid attention layers returned"
|
| 2481 |
+
}
|
| 2482 |
+
|
| 2483 |
+
stacked = torch.stack(captured_layers_list)
|
| 2484 |
+
if bsz == 1:
|
| 2485 |
+
all_layers_matrix = stacked.squeeze(1)
|
| 2486 |
+
else:
|
| 2487 |
+
all_layers_matrix = stacked
|
| 2488 |
+
|
| 2489 |
+
# Process lyric token IDs to extract pure lyrics
|
| 2490 |
+
if isinstance(lyric_token_ids, torch.Tensor):
|
| 2491 |
+
raw_lyric_ids = lyric_token_ids[0].tolist()
|
| 2492 |
+
else:
|
| 2493 |
+
raw_lyric_ids = lyric_token_ids
|
| 2494 |
+
|
| 2495 |
+
# Parse header to find lyrics start position
|
| 2496 |
+
header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
|
| 2497 |
+
header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
|
| 2498 |
+
start_idx = len(header_ids)
|
| 2499 |
+
|
| 2500 |
+
# Find end of lyrics (before endoftext token)
|
| 2501 |
+
try:
|
| 2502 |
+
end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token
|
| 2503 |
+
except ValueError:
|
| 2504 |
+
end_idx = len(raw_lyric_ids)
|
| 2505 |
+
|
| 2506 |
+
pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
|
| 2507 |
+
pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :]
|
| 2508 |
+
|
| 2509 |
+
# Create aligner and generate timestamps
|
| 2510 |
+
aligner = MusicStampsAligner(self.text_tokenizer)
|
| 2511 |
+
|
| 2512 |
+
align_info = aligner.stamps_align_info(
|
| 2513 |
+
attention_matrix=pure_lyric_matrix,
|
| 2514 |
+
lyrics_tokens=pure_lyric_ids,
|
| 2515 |
+
total_duration_seconds=total_duration_seconds,
|
| 2516 |
+
custom_config=custom_layers_config,
|
| 2517 |
+
return_matrices=False,
|
| 2518 |
+
violence_level=2.0,
|
| 2519 |
+
medfilt_width=1,
|
| 2520 |
+
)
|
| 2521 |
+
|
| 2522 |
+
if align_info.get("calc_matrix") is None:
|
| 2523 |
+
return {
|
| 2524 |
+
"lrc_text": "",
|
| 2525 |
+
"sentence_timestamps": [],
|
| 2526 |
+
"token_timestamps": [],
|
| 2527 |
+
"success": False,
|
| 2528 |
+
"error": align_info.get("error", "Failed to process attention matrix")
|
| 2529 |
+
}
|
| 2530 |
+
|
| 2531 |
+
# Generate timestamps
|
| 2532 |
+
result = aligner.get_timestamps_and_lrc(
|
| 2533 |
+
calc_matrix=align_info["calc_matrix"],
|
| 2534 |
+
lyrics_tokens=pure_lyric_ids,
|
| 2535 |
+
total_duration_seconds=total_duration_seconds
|
| 2536 |
+
)
|
| 2537 |
+
|
| 2538 |
+
return {
|
| 2539 |
+
"lrc_text": result["lrc_text"],
|
| 2540 |
+
"sentence_timestamps": result["sentence_timestamps"],
|
| 2541 |
+
"token_timestamps": result["token_timestamps"],
|
| 2542 |
+
"success": True,
|
| 2543 |
+
"error": None
|
| 2544 |
+
}
|
| 2545 |
+
|
| 2546 |
+
except Exception as e:
|
| 2547 |
+
error_msg = f"Error generating timestamps: {str(e)}"
|
| 2548 |
+
logger.exception("[get_lyric_timestamp] Failed")
|
| 2549 |
+
return {
|
| 2550 |
+
"lrc_text": "",
|
| 2551 |
+
"sentence_timestamps": [],
|
| 2552 |
+
"token_timestamps": [],
|
| 2553 |
+
"success": False,
|
| 2554 |
+
"error": error_msg
|
| 2555 |
+
}
|
pyproject.toml
CHANGED
|
@@ -30,7 +30,7 @@ dependencies = [
|
|
| 30 |
"uvicorn[standard]>=0.27.0",
|
| 31 |
|
| 32 |
# Local third-party packages
|
| 33 |
-
"nano-vllm @
|
| 34 |
]
|
| 35 |
|
| 36 |
[project.scripts]
|
|
@@ -41,8 +41,8 @@ acestep-api = "acestep.api_server:main"
|
|
| 41 |
requires = ["hatchling"]
|
| 42 |
build-backend = "hatchling.build"
|
| 43 |
|
| 44 |
-
[
|
| 45 |
-
dev
|
| 46 |
|
| 47 |
[[tool.uv.index]]
|
| 48 |
name = "pytorch"
|
|
|
|
| 30 |
"uvicorn[standard]>=0.27.0",
|
| 31 |
|
| 32 |
# Local third-party packages
|
| 33 |
+
"nano-vllm @ {root:uri}/acestep/third_parts/nano-vllm",
|
| 34 |
]
|
| 35 |
|
| 36 |
[project.scripts]
|
|
|
|
| 41 |
requires = ["hatchling"]
|
| 42 |
build-backend = "hatchling.build"
|
| 43 |
|
| 44 |
+
[dependency-groups]
|
| 45 |
+
dev = []
|
| 46 |
|
| 47 |
[[tool.uv.index]]
|
| 48 |
name = "pytorch"
|