ChuxiJ commited on
Commit
6922ca4
·
1 Parent(s): bb87271

add lrc support

Browse files
acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional
15
+
16
+
17
+ # ================= Data Classes =================
18
@dataclass
class TokenTimestamp:
    """Stores per-token timing information."""
    token_id: int  # vocabulary id of the lyrics token
    text: str  # decoded text contributed by this token ("" when it only carries a partial UTF-8 sequence)
    start: float  # start time in seconds
    end: float  # end time in seconds (>= start)
    probability: float  # per-token confidence; currently always 0.0 (not computed by the pipeline)
26
+
27
+
28
@dataclass
class SentenceTimestamp:
    """Stores per-sentence timing information with token list."""
    text: str  # decoded sentence text, whitespace-stripped
    start: float  # sentence start in seconds, rounded to 3 decimals
    end: float  # sentence end in seconds, rounded to 3 decimals
    tokens: List[TokenTimestamp]  # the tokens composing this sentence
    confidence: float  # min-max normalized across sentences by sentence_timestamps()
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Fills a cumulative-cost table over the cost matrix and returns the
    minimum-cost monotonic alignment path via `_backtrace`.

    Args:
        x: Cost matrix of shape [N, M] (lower values = better match)

    Returns:
        int32 array of shape (2, path_len): row 0 holds text (token) indices,
        row 1 holds time (frame) indices; unpackable as
        `text_indices, time_indices = dtw_cpu(x)`
    """
    N, M = x.shape
    # Use float32 for memory efficiency
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    # Move codes stored as floats: single dtype keeps numba's nopython mode happy
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # Predecessors: diagonal, above (advance token), left (advance frame)
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            # NOTE(review): strict '<' comparisons mean exact ties fall through
            # to the 'left' move (t=2) even when it is not the minimum — same
            # quirk as the reference Whisper implementation; confirm intended.
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return _backtrace(trace, N, M)
73
+
74
+
75
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    Walks the trace matrix from (N, M) back to (0, 0), writing visited cells
    into a pre-allocated buffer from the back so no reversal pass is needed.

    Args:
        trace: Trace matrix of shape (N+1, M+1) holding move codes
            (0 = diagonal, 1 = up, 2 = left)
        N, M: Original matrix dimensions

    Returns:
        Path array of shape (2, path_len) - first row is text indices,
        second is time indices
    """
    # Boundary handling: force pure horizontal/vertical moves along the edges
    trace[0, :] = 2
    trace[:, 0] = 1

    # Pre-allocate: a monotonic path visits at most N + M cells
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)

    i, j = N, M
    path_idx = max_path_len - 1  # fill backwards from the end of the buffer

    while i > 0 or j > 0:
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1

        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            # Defensive stop on an unexpected move code
            break

    # Slice off the unused front of the buffer
    # (fixed: removed dead `actual_len` variable that was computed but never used)
    return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a sliding median filter along the last dimension.

    Args:
        x: Input tensor (2-D or 3-D); filtering runs over the last axis
        filter_width: Width of the median window

    Returns:
        Tensor with each element replaced by its window median; the input is
        returned unchanged when the last axis is too short for reflect padding.

    NOTE(review): a 3-D input whose leading dim has size 1 is squeezed to
    2-D on return — confirm callers expect this rank change for single-head
    stacks.
    """
    half = filter_width // 2
    # Reflect padding requires the signal to be longer than the pad amount
    if x.shape[-1] <= half:
        return x

    if x.ndim == 2:
        # F.pad's 2-D reflect mode needs at least a 3-D input
        x = x[None, :]

    padded = F.pad(x, (half, half, 0, 0), mode="reflect")
    windows = padded.unfold(-1, filter_width, 1)
    # Median = middle element of each sorted window
    filtered = windows.sort()[0][..., half]

    if filtered.ndim > 2:
        filtered = filtered.squeeze(0)
    return filtered
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
class MusicStampsAligner:
    """
    Aligner class for generating lyrics timestamps from cross-attention matrices.

    Uses bidirectional consensus denoising and DTW for alignment.
    """

    def __init__(self, tokenizer):
        """
        Initialize the aligner.

        Args:
            tokenizer: Text tokenizer for decoding tokens; must expose a
                ``decode(token_ids, skip_special_tokens=...)`` method
                (presumably a HuggingFace-style tokenizer — confirm).
        """
        self.tokenizer = tokenizer
158
+
159
+ def _apply_bidirectional_consensus(
160
+ self,
161
+ weights_stack: torch.Tensor,
162
+ violence_level: float,
163
+ medfilt_width: int
164
+ ) -> tuple:
165
+ """
166
+ Core denoising logic using bidirectional consensus.
167
+
168
+ Args:
169
+ weights_stack: Attention weights [Heads, Tokens, Frames]
170
+ violence_level: Denoising strength coefficient
171
+ medfilt_width: Median filter width
172
+
173
+ Returns:
174
+ Tuple of (calc_matrix, energy_matrix) as numpy arrays
175
+ """
176
+ # A. Bidirectional Consensus
177
+ row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
178
+ col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
179
+ processed = row_prob * col_prob
180
+
181
+ # 1. Row suppression (kill horizontal crossing lines)
182
+ row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
183
+ processed = processed - (violence_level * row_medians)
184
+ processed = torch.relu(processed)
185
+
186
+ # 2. Column suppression (kill vertical crossing lines)
187
+ col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
188
+ processed = processed - (violence_level * col_medians)
189
+ processed = torch.relu(processed)
190
+
191
+ # C. Power sharpening
192
+ processed = processed ** 2
193
+
194
+ # Energy matrix for confidence
195
+ energy_matrix = processed.mean(dim=0).cpu().numpy()
196
+
197
+ # D. Z-Score normalization
198
+ std, mean = torch.std_mean(processed, unbiased=False)
199
+ weights_processed = (processed - mean) / (std + 1e-9)
200
+
201
+ # E. Median filtering
202
+ weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
203
+ calc_matrix = weights_processed.mean(dim=0).numpy()
204
+
205
+ return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
+ def token_timestamps(
345
+ self,
346
+ calc_matrix: np.ndarray,
347
+ lyrics_tokens: List[int],
348
+ total_duration_seconds: float
349
+ ) -> List[TokenTimestamp]:
350
+ """
351
+ Generate per-token timestamps using DTW.
352
+
353
+ Args:
354
+ calc_matrix: Processed attention matrix [Tokens, Frames]
355
+ lyrics_tokens: List of token IDs
356
+ total_duration_seconds: Total audio duration
357
+
358
+ Returns:
359
+ List of TokenTimestamp objects
360
+ """
361
+ n_frames = calc_matrix.shape[-1]
362
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
363
+
364
+ seconds_per_frame = total_duration_seconds / n_frames
365
+ alignment_results = []
366
+
367
+ # Use incremental decoding to properly handle multi-byte UTF-8 characters
368
+ decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
369
+
370
+ for i in range(len(lyrics_tokens)):
371
+ mask = (text_indices == i)
372
+
373
+ if not np.any(mask):
374
+ start = alignment_results[-1].end if alignment_results else 0.0
375
+ end = start
376
+ token_conf = 0.0
377
+ else:
378
+ times = time_indices[mask] * seconds_per_frame
379
+ start = times[0]
380
+ end = times[-1]
381
+ token_conf = 0.0
382
+
383
+ if end < start:
384
+ end = start
385
+
386
+ alignment_results.append(TokenTimestamp(
387
+ token_id=lyrics_tokens[i],
388
+ text=decoded_tokens[i],
389
+ start=float(start),
390
+ end=float(end),
391
+ probability=token_conf
392
+ ))
393
+
394
+ return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
acestep/gradio_ui/events/__init__.py CHANGED
@@ -358,19 +358,49 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
358
  )
359
 
360
  # ========== Score Calculation Handlers ==========
 
 
 
 
 
 
361
  for btn_idx in range(1, 9):
362
  results_section[f"score_btn_{btn_idx}"].click(
363
- fn=lambda sample_idx, scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
364
- llm_handler, sample_idx, scale, batch_idx, queue
365
- ),
366
  inputs=[
367
- gr.State(value=btn_idx),
368
  generation_section["score_scale"],
369
  results_section["current_batch_index"],
370
  results_section["batch_queue"],
371
  ],
372
- outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
 
 
 
 
373
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
374
  def generation_wrapper(*args):
375
  yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
376
  # ========== Generation Handler ==========
@@ -438,12 +468,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
438
  results_section["generation_info"],
439
  results_section["status_output"],
440
  generation_section["seed"],
441
- results_section["align_score_1"],
442
- results_section["align_text_1"],
443
- results_section["align_plot_1"],
444
- results_section["align_score_2"],
445
- results_section["align_text_2"],
446
- results_section["align_plot_2"],
447
  results_section["score_display_1"],
448
  results_section["score_display_2"],
449
  results_section["score_display_3"],
 
358
  )
359
 
360
  # ========== Score Calculation Handlers ==========
361
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
362
+ def make_score_handler(idx):
363
+ return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
364
+ llm_handler, idx, scale, batch_idx, queue
365
+ )
366
+
367
  for btn_idx in range(1, 9):
368
  results_section[f"score_btn_{btn_idx}"].click(
369
+ fn=make_score_handler(btn_idx),
 
 
370
  inputs=[
 
371
  generation_section["score_scale"],
372
  results_section["current_batch_index"],
373
  results_section["batch_queue"],
374
  ],
375
+ outputs=[
376
+ results_section[f"score_display_{btn_idx}"],
377
+ results_section[f"details_accordion_{btn_idx}"],
378
+ results_section["batch_queue"]
379
+ ]
380
  )
381
+
382
+ # ========== LRC Timestamp Handlers ==========
383
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
384
+ def make_lrc_handler(idx):
385
+ return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
386
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
387
+ )
388
+
389
+ for btn_idx in range(1, 9):
390
+ results_section[f"lrc_btn_{btn_idx}"].click(
391
+ fn=make_lrc_handler(btn_idx),
392
+ inputs=[
393
+ results_section["current_batch_index"],
394
+ results_section["batch_queue"],
395
+ generation_section["vocal_language"],
396
+ generation_section["inference_steps"],
397
+ ],
398
+ outputs=[
399
+ results_section[f"lrc_display_{btn_idx}"],
400
+ results_section[f"details_accordion_{btn_idx}"]
401
+ ]
402
+ )
403
+
404
  def generation_wrapper(*args):
405
  yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
406
  # ========== Generation Handler ==========
 
468
  results_section["generation_info"],
469
  results_section["status_output"],
470
  generation_section["seed"],
 
 
 
 
 
 
471
  results_section["score_display_1"],
472
  results_section["score_display_2"],
473
  results_section["score_display_3"],
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -141,6 +141,7 @@ def store_batch_in_queue(
141
  batch_size=2,
142
  generation_params=None,
143
  lm_generated_metadata=None,
 
144
  status="completed"
145
  ):
146
  """Store batch results in queue with ALL generation parameters
@@ -152,6 +153,7 @@ def store_batch_in_queue(
152
  batch_size: Batch size used for this batch
153
  generation_params: Complete dictionary of ALL generation parameters used
154
  lm_generated_metadata: LM-generated metadata for scoring (optional)
 
155
  """
156
  batch_queue[batch_index] = {
157
  "status": status,
@@ -164,6 +166,7 @@ def store_batch_in_queue(
164
  "batch_size": batch_size, # Store batch size
165
  "generation_params": generation_params if generation_params else {}, # Store ALL parameters
166
  "lm_generated_metadata": lm_generated_metadata, # Store LM metadata for scoring
 
167
  "timestamp": datetime.datetime.now().isoformat()
168
  }
169
  return batch_queue
@@ -355,12 +358,6 @@ def generate_with_progress(
355
  audio_conversion_start_time = time_module.time()
356
  total_auto_score_time = 0.0
357
 
358
- align_score_1 = ""
359
- align_text_1 = ""
360
- align_plot_1 = None
361
- align_score_2 = ""
362
- align_text_2 = ""
363
- align_plot_2 = None
364
  updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
365
 
366
  # Build initial generation_info (will be updated with post-processing times at the end)
@@ -373,7 +370,7 @@ def generate_with_progress(
373
  )
374
 
375
  if not result.success:
376
- yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 26
377
  return
378
 
379
  audios = result.audios
@@ -421,8 +418,6 @@ def generate_with_progress(
421
  generation_info,
422
  status_message,
423
  seed_value_for_ui,
424
- # Align plot placeholders (assume no need to update in real time)
425
- gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip(),
426
  # Scores
427
  scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
428
  updated_audio_codes,
@@ -431,6 +426,7 @@ def generate_with_progress(
431
  audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
432
  lm_generated_metadata,
433
  is_format_caption,
 
434
  )
435
  else:
436
  # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
@@ -467,7 +463,6 @@ def generate_with_progress(
467
  generation_info,
468
  "Generation Complete",
469
  seed_value_for_ui,
470
- align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2,
471
  final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
472
  final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
473
  updated_audio_codes,
@@ -475,6 +470,7 @@ def generate_with_progress(
475
  final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
476
  lm_generated_metadata,
477
  is_format_caption,
 
478
  )
479
 
480
 
@@ -595,7 +591,7 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
595
  batch_queue: Batch queue containing historical generation data
596
  """
597
  if current_batch_index not in batch_queue:
598
- return t("messages.scoring_failed"), batch_queue
599
 
600
  batch_data = batch_queue[current_batch_index]
601
  params = batch_data.get("generation_params", {})
@@ -642,7 +638,106 @@ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale,
642
  batch_queue[current_batch_index]["scores"] = [""] * 8
643
  batch_queue[current_batch_index]["scores"][sample_idx - 1] = score_display
644
 
645
- return score_display, batch_queue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
 
648
  def capture_current_params(
@@ -758,7 +853,9 @@ def generate_with_batch_management(
758
  final_result_from_inner = partial_result
759
  # current_batch_index, total_batches, batch_queue, next_params,
760
  # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
761
- yield partial_result + (
 
 
762
  gr.skip(), gr.skip(), gr.skip(), gr.skip(),
763
  gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
764
  )
@@ -766,21 +863,23 @@ def generate_with_batch_management(
766
  all_audio_paths = result[8]
767
 
768
  if all_audio_paths is None:
769
-
770
- yield result + (
 
771
  gr.skip(), gr.skip(), gr.skip(), gr.skip(),
772
  gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
773
  )
774
  return
775
 
776
  # Extract results from generation (使用 result 下标访问)
 
777
  generation_info = result[9]
778
  seed_value_for_ui = result[11]
779
- lm_generated_metadata = result[35] # Fixed: lm_metadata is at index 35, not 34
780
 
781
  # Extract codes
782
- generated_codes_single = result[26]
783
- generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
784
 
785
  # Determine which codes to store based on mode
786
  if allow_lm_batch and batch_size_input >= 2:
@@ -839,6 +938,9 @@ def generate_with_batch_management(
839
  next_params["text2music_audio_code_string"] = ""
840
  next_params["random_seed_checkbox"] = True
841
 
 
 
 
842
  # Store current batch in queue
843
  batch_queue = store_batch_in_queue(
844
  batch_queue,
@@ -851,6 +953,7 @@ def generate_with_batch_management(
851
  batch_size=int(batch_size_input),
852
  generation_params=saved_params,
853
  lm_generated_metadata=lm_generated_metadata,
 
854
  status="completed"
855
  )
856
 
@@ -870,7 +973,9 @@ def generate_with_batch_management(
870
 
871
  # 4. Yield final result (includes Batch UI updates)
872
  # The result here is already a tuple structure
873
- yield result + (
 
 
874
  current_batch_index,
875
  total_batches,
876
  batch_queue,
@@ -1040,14 +1145,15 @@ def generate_next_batch_background(
1040
  final_result = partial_result
1041
 
1042
  # Extract results from final_result
 
1043
  all_audio_paths = final_result[8] # generated_audio_batch
1044
  generation_info = final_result[9]
1045
  seed_value_for_ui = final_result[11]
1046
- lm_generated_metadata = final_result[35] # Fixed: lm_metadata is at index 35, not 34
1047
 
1048
  # Extract codes
1049
- generated_codes_single = final_result[26]
1050
- generated_codes_batch = [final_result[27], final_result[28], final_result[29], final_result[30], final_result[31], final_result[32], final_result[33], final_result[34]]
1051
 
1052
  # Determine which codes to store
1053
  batch_size = params.get("batch_size_input", 2)
@@ -1070,6 +1176,7 @@ def generate_next_batch_background(
1070
  logger.info(f" - codes_to_store: STRING with {len(codes_to_store) if codes_to_store else 0} chars")
1071
 
1072
  # Store next batch in queue with codes, batch settings, and ALL generation params
 
1073
  batch_queue = store_batch_in_queue(
1074
  batch_queue,
1075
  next_batch_idx,
@@ -1081,6 +1188,7 @@ def generate_next_batch_background(
1081
  batch_size=int(batch_size),
1082
  generation_params=params,
1083
  lm_generated_metadata=lm_generated_metadata,
 
1084
  status="completed"
1085
  )
1086
 
 
141
  batch_size=2,
142
  generation_params=None,
143
  lm_generated_metadata=None,
144
+ extra_outputs=None,
145
  status="completed"
146
  ):
147
  """Store batch results in queue with ALL generation parameters
 
153
  batch_size: Batch size used for this batch
154
  generation_params: Complete dictionary of ALL generation parameters used
155
  lm_generated_metadata: LM-generated metadata for scoring (optional)
156
+ extra_outputs: Dictionary containing pred_latents, encoder_hidden_states, etc. for LRC generation
157
  """
158
  batch_queue[batch_index] = {
159
  "status": status,
 
166
  "batch_size": batch_size, # Store batch size
167
  "generation_params": generation_params if generation_params else {}, # Store ALL parameters
168
  "lm_generated_metadata": lm_generated_metadata, # Store LM metadata for scoring
169
+ "extra_outputs": extra_outputs if extra_outputs else {}, # Store extra outputs for LRC generation
170
  "timestamp": datetime.datetime.now().isoformat()
171
  }
172
  return batch_queue
 
358
  audio_conversion_start_time = time_module.time()
359
  total_auto_score_time = 0.0
360
 
 
 
 
 
 
 
361
  updated_audio_codes = text2music_audio_code_string if not think_checkbox else ""
362
 
363
  # Build initial generation_info (will be updated with post-processing times at the end)
 
370
  )
371
 
372
  if not result.success:
373
+ yield (None,) * 8 + (None, generation_info, result.status_message) + (gr.skip(),) * 20 + (None,) # +1 for extra_outputs
374
  return
375
 
376
  audios = result.audios
 
418
  generation_info,
419
  status_message,
420
  seed_value_for_ui,
 
 
421
  # Scores
422
  scores_ui_updates[0], scores_ui_updates[1], scores_ui_updates[2], scores_ui_updates[3], scores_ui_updates[4], scores_ui_updates[5], scores_ui_updates[6], scores_ui_updates[7],
423
  updated_audio_codes,
 
426
  audio_codes_ui_updates[4], audio_codes_ui_updates[5], audio_codes_ui_updates[6], audio_codes_ui_updates[7],
427
  lm_generated_metadata,
428
  is_format_caption,
429
+ None, # Placeholder for extra_outputs (only filled in final yield)
430
  )
431
  else:
432
  # If i exceeds the generated count (e.g., batch=2, i=2..7), do not yield
 
463
  generation_info,
464
  "Generation Complete",
465
  seed_value_for_ui,
 
466
  final_scores_list[0], final_scores_list[1], final_scores_list[2], final_scores_list[3],
467
  final_scores_list[4], final_scores_list[5], final_scores_list[6], final_scores_list[7],
468
  updated_audio_codes,
 
470
  final_codes_list[4], final_codes_list[5], final_codes_list[6], final_codes_list[7],
471
  lm_generated_metadata,
472
  is_format_caption,
473
+ result.extra_outputs, # extra_outputs for LRC generation
474
  )
475
 
476
 
 
591
  batch_queue: Batch queue containing historical generation data
592
  """
593
  if current_batch_index not in batch_queue:
594
+ return gr.skip(), gr.skip(), batch_queue
595
 
596
  batch_data = batch_queue[current_batch_index]
597
  params = batch_data.get("generation_params", {})
 
638
  batch_queue[current_batch_index]["scores"] = [""] * 8
639
  batch_queue[current_batch_index]["scores"][sample_idx - 1] = score_display
640
 
641
+ # Return: score_display (content + visible), accordion visible, batch_queue
642
+ return (
643
+ gr.update(value=score_display, visible=True), # score_display with content
644
+ gr.update(visible=True), # details_accordion
645
+ batch_queue
646
+ )
647
+
648
+
649
+ def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
650
+ """
651
+ Generate LRC timestamps for a specific audio sample.
652
+
653
+ This function retrieves cached generation data from batch_queue and calls
654
+ the handler's get_lyric_timestamp method to generate LRC format lyrics.
655
+
656
+ Args:
657
+ dit_handler: DiT handler instance with get_lyric_timestamp method
658
+ sample_idx: Which sample to generate LRC for (1-8)
659
+ current_batch_index: Current batch index in batch_queue
660
+ batch_queue: Dictionary storing all batch generation data
661
+ vocal_language: Language code for lyrics
662
+ inference_steps: Number of inference steps used in generation
663
+
664
+ Returns:
665
+ LRC formatted string or error message
666
+ """
667
+ import torch
668
+
669
+ if current_batch_index not in batch_queue:
670
+ return gr.skip(), gr.skip()
671
+
672
+ batch_data = batch_queue[current_batch_index]
673
+ extra_outputs = batch_data.get("extra_outputs", {})
674
+
675
+ # Check if required data is available
676
+ if not extra_outputs:
677
+ return gr.update(value=t("messages.lrc_no_extra_outputs"), visible=True), gr.update(visible=True)
678
+
679
+ pred_latents = extra_outputs.get("pred_latents")
680
+ encoder_hidden_states = extra_outputs.get("encoder_hidden_states")
681
+ encoder_attention_mask = extra_outputs.get("encoder_attention_mask")
682
+ context_latents = extra_outputs.get("context_latents")
683
+ lyric_token_idss = extra_outputs.get("lyric_token_idss")
684
+
685
+ if any(x is None for x in [pred_latents, encoder_hidden_states, encoder_attention_mask, context_latents, lyric_token_idss]):
686
+ return gr.update(value=t("messages.lrc_missing_tensors"), visible=True), gr.update(visible=True)
687
+
688
+ # Adjust sample_idx to 0-based
689
+ sample_idx_0based = sample_idx - 1
690
+
691
+ # Check if sample exists in batch
692
+ batch_size = pred_latents.shape[0]
693
+ if sample_idx_0based >= batch_size:
694
+ return gr.update(value=t("messages.lrc_sample_not_exist"), visible=True), gr.update(visible=True)
695
+
696
+ # Extract the specific sample's data
697
+ try:
698
+ # Get audio duration from batch data
699
+ params = batch_data.get("generation_params", {})
700
+ audio_duration = params.get("audio_duration", -1)
701
+
702
+ # Calculate duration from latents if not specified
703
+ if audio_duration is None or audio_duration <= 0:
704
+ # latent_length * frames_per_second_ratio ≈ audio_duration
705
+ # Assuming 25 Hz latent rate: latent_length / 25 = duration
706
+ latent_length = pred_latents.shape[1]
707
+ audio_duration = latent_length / 25.0 # 25 Hz latent rate
708
+
709
+ # Get the sample's data (keep batch dimension for handler)
710
+ sample_pred_latent = pred_latents[sample_idx_0based:sample_idx_0based+1]
711
+ sample_encoder_hidden_states = encoder_hidden_states[sample_idx_0based:sample_idx_0based+1]
712
+ sample_encoder_attention_mask = encoder_attention_mask[sample_idx_0based:sample_idx_0based+1]
713
+ sample_context_latents = context_latents[sample_idx_0based:sample_idx_0based+1]
714
+ sample_lyric_token_ids = lyric_token_idss[sample_idx_0based:sample_idx_0based+1]
715
+
716
+ # Call handler to generate timestamps
717
+ result = dit_handler.get_lyric_timestamp(
718
+ pred_latent=sample_pred_latent,
719
+ encoder_hidden_states=sample_encoder_hidden_states,
720
+ encoder_attention_mask=sample_encoder_attention_mask,
721
+ context_latents=sample_context_latents,
722
+ lyric_token_ids=sample_lyric_token_ids,
723
+ total_duration_seconds=float(audio_duration),
724
+ vocal_language=vocal_language or "en",
725
+ inference_steps=int(inference_steps),
726
+ seed=42, # Use fixed seed for reproducibility
727
+ )
728
+
729
+ if result.get("success"):
730
+ lrc_text = result.get("lrc_text", "")
731
+ if not lrc_text:
732
+ return gr.update(value=t("messages.lrc_empty_result"), visible=True), gr.update(visible=True)
733
+ return gr.update(value=lrc_text, visible=True), gr.update(visible=True)
734
+ else:
735
+ error_msg = result.get("error", "Unknown error")
736
+ return gr.update(value=f"❌ {error_msg}", visible=True), gr.update(visible=True)
737
+
738
+ except Exception as e:
739
+ logger.exception("[generate_lrc_handler] Error generating LRC")
740
+ return gr.update(value=f"❌ Error: {str(e)}", visible=True), gr.update(visible=True)
741
 
742
 
743
  def capture_current_params(
 
853
  final_result_from_inner = partial_result
854
  # current_batch_index, total_batches, batch_queue, next_params,
855
  # batch_indicator_text, prev_btn, next_btn, next_status, restore_btn
856
+ # Slice off extra_outputs (last item) before re-yielding to UI
857
+ ui_result = partial_result[:-1] if len(partial_result) > 31 else partial_result
858
+ yield ui_result + (
859
  gr.skip(), gr.skip(), gr.skip(), gr.skip(),
860
  gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
861
  )
 
863
  all_audio_paths = result[8]
864
 
865
  if all_audio_paths is None:
866
+ # Slice off extra_outputs before yielding to UI
867
+ ui_result = result[:-1] if len(result) > 31 else result
868
+ yield ui_result + (
869
  gr.skip(), gr.skip(), gr.skip(), gr.skip(),
870
  gr.skip(), gr.skip(), gr.skip(), gr.skip(), gr.skip()
871
  )
872
  return
873
 
874
  # Extract results from generation (使用 result 下标访问)
875
+ # New indices after removing 6 align_* items (was 12-17, now shifted down by 6)
876
  generation_info = result[9]
877
  seed_value_for_ui = result[11]
878
+ lm_generated_metadata = result[29] # was 35, now 29
879
 
880
  # Extract codes
881
+ generated_codes_single = result[20] # was 26, now 20
882
+ generated_codes_batch = [result[21], result[22], result[23], result[24], result[25], result[26], result[27], result[28]] # was 27-34, now 21-28
883
 
884
  # Determine which codes to store based on mode
885
  if allow_lm_batch and batch_size_input >= 2:
 
938
  next_params["text2music_audio_code_string"] = ""
939
  next_params["random_seed_checkbox"] = True
940
 
941
+ # Extract extra_outputs from result tuple (index 31)
942
+ extra_outputs_from_result = result[31] if len(result) > 31 else {}
943
+
944
  # Store current batch in queue
945
  batch_queue = store_batch_in_queue(
946
  batch_queue,
 
953
  batch_size=int(batch_size_input),
954
  generation_params=saved_params,
955
  lm_generated_metadata=lm_generated_metadata,
956
+ extra_outputs=extra_outputs_from_result, # Store extra outputs for LRC generation
957
  status="completed"
958
  )
959
 
 
973
 
974
  # 4. Yield final result (includes Batch UI updates)
975
  # The result here is already a tuple structure
976
+ # Slice off extra_outputs (last item) before yielding to UI - it's already stored in batch_queue
977
+ ui_result = result[:-1] if len(result) > 31 else result
978
+ yield ui_result + (
979
  current_batch_index,
980
  total_batches,
981
  batch_queue,
 
1145
  final_result = partial_result
1146
 
1147
  # Extract results from final_result
1148
+ # Indices shifted by -6 after removing align_* items
1149
  all_audio_paths = final_result[8] # generated_audio_batch
1150
  generation_info = final_result[9]
1151
  seed_value_for_ui = final_result[11]
1152
+ lm_generated_metadata = final_result[29] # was 35, now 29
1153
 
1154
  # Extract codes
1155
+ generated_codes_single = final_result[20] # was 26, now 20
1156
+ generated_codes_batch = [final_result[21], final_result[22], final_result[23], final_result[24], final_result[25], final_result[26], final_result[27], final_result[28]] # was 27-34, now 21-28
1157
 
1158
  # Determine which codes to store
1159
  batch_size = params.get("batch_size_input", 2)
 
1176
  logger.info(f" - codes_to_store: STRING with {len(codes_to_store) if codes_to_store else 0} chars")
1177
 
1178
  # Store next batch in queue with codes, batch settings, and ALL generation params
1179
+ # Note: extra_outputs not available for background batches (LRC not supported for auto-gen batches)
1180
  batch_queue = store_batch_in_queue(
1181
  batch_queue,
1182
  next_batch_idx,
 
1188
  batch_size=int(batch_size),
1189
  generation_params=params,
1190
  lm_generated_metadata=lm_generated_metadata,
1191
+ extra_outputs=None, # Not available for background batches
1192
  status="completed"
1193
  )
1194
 
acestep/gradio_ui/i18n/en.json CHANGED
@@ -148,8 +148,6 @@
148
  "cover_strength_info": "Control how many denoising steps use cover mode",
149
  "score_sensitivity_label": "Quality Score Sensitivity",
150
  "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
151
- "attention_focus_label": "Output Attention Focus Score (disabled)",
152
- "attention_focus_info": "Output attention focus score analysis",
153
  "think_label": "Think",
154
  "parallel_thinking_label": "ParallelThinking",
155
  "generate_btn": "🎵 Generate Music",
@@ -162,8 +160,12 @@
162
  "send_to_src_btn": "🔗 Send To Src Audio",
163
  "save_btn": "💾 Save",
164
  "score_btn": "📊 Score",
 
165
  "quality_score_label": "Quality Score (Sample {n})",
166
  "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
 
 
 
167
  "generation_status": "Generation Status",
168
  "current_batch": "Current Batch",
169
  "batch_indicator": "Batch {current} / {total}",
@@ -173,11 +175,7 @@
173
  "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
174
  "batch_results_title": "📁 Batch Results & Generation Details",
175
  "all_files_label": "📁 All Generated Files (Download)",
176
- "generation_details": "Generation Details",
177
- "attention_analysis": "⚖️ Attention Focus Score Analysis",
178
- "attention_score": "Attention Focus Score (Sample {n})",
179
- "lyric_timestamps": "Lyric Timestamps (Sample {n})",
180
- "attention_heatmap": "Attention Focus Score Heatmap (Sample {n})"
181
  },
182
  "messages": {
183
  "no_audio_to_save": "❌ No audio to save",
@@ -206,6 +204,11 @@
206
  "scoring_failed": "❌ Error: Batch data not found",
207
  "no_codes": "❌ No audio codes available. Please generate music first.",
208
  "score_failed": "❌ Scoring failed: {error}",
209
- "score_error": "❌ Error calculating score: {error}"
 
 
 
 
 
210
  }
211
  }
 
148
  "cover_strength_info": "Control how many denoising steps use cover mode",
149
  "score_sensitivity_label": "Quality Score Sensitivity",
150
  "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
 
 
151
  "think_label": "Think",
152
  "parallel_thinking_label": "ParallelThinking",
153
  "generate_btn": "🎵 Generate Music",
 
160
  "send_to_src_btn": "🔗 Send To Src Audio",
161
  "save_btn": "💾 Save",
162
  "score_btn": "📊 Score",
163
+ "lrc_btn": "🎵 LRC",
164
  "quality_score_label": "Quality Score (Sample {n})",
165
  "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
166
+ "lrc_label": "Lyrics Timestamps (Sample {n})",
167
+ "lrc_placeholder": "Click 'LRC' to generate timestamps",
168
+ "details_accordion": "📊 Score & LRC",
169
  "generation_status": "Generation Status",
170
  "current_batch": "Current Batch",
171
  "batch_indicator": "Batch {current} / {total}",
 
175
  "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
176
  "batch_results_title": "📁 Batch Results & Generation Details",
177
  "all_files_label": "📁 All Generated Files (Download)",
178
+ "generation_details": "Generation Details"
 
 
 
 
179
  },
180
  "messages": {
181
  "no_audio_to_save": "❌ No audio to save",
 
204
  "scoring_failed": "❌ Error: Batch data not found",
205
  "no_codes": "❌ No audio codes available. Please generate music first.",
206
  "score_failed": "❌ Scoring failed: {error}",
207
+ "score_error": "❌ Error calculating score: {error}",
208
+ "lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
209
+ "lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
210
+ "lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
211
+ "lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
212
+ "lrc_empty_result": "⚠️ LRC generation produced empty result."
213
  }
214
  }
acestep/gradio_ui/i18n/ja.json CHANGED
@@ -148,8 +148,6 @@
148
  "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
149
  "score_sensitivity_label": "品質スコア感度",
150
  "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
151
- "attention_focus_label": "注意焦点スコアを出力(無効)",
152
- "attention_focus_info": "注意焦点スコア分析を出力",
153
  "think_label": "思考",
154
  "parallel_thinking_label": "並列思考",
155
  "generate_btn": "🎵 音楽を生成",
@@ -162,8 +160,12 @@
162
  "send_to_src_btn": "🔗 ソースオーディオに送信",
163
  "save_btn": "💾 保存",
164
  "score_btn": "📊 スコア",
 
165
  "quality_score_label": "品質スコア(サンプル {n})",
166
  "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
 
 
 
167
  "generation_status": "生成ステータス",
168
  "current_batch": "現在のバッチ",
169
  "batch_indicator": "バッチ {current} / {total}",
@@ -173,11 +175,7 @@
173
  "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
174
  "batch_results_title": "📁 バッチ結果と生成詳細",
175
  "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
176
- "generation_details": "生成詳細",
177
- "attention_analysis": "⚖️ 注意焦点スコア分析",
178
- "attention_score": "注意焦点スコア(サンプル {n})",
179
- "lyric_timestamps": "歌詞タイムスタンプ(サンプル {n})",
180
- "attention_heatmap": "注意焦点スコアヒートマップ(サンプル {n})"
181
  },
182
  "messages": {
183
  "no_audio_to_save": "❌ 保存するオーディオがありません",
@@ -206,6 +204,11 @@
206
  "scoring_failed": "❌ エラー: バッチデータが見つかりません",
207
  "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
208
  "score_failed": "❌ スコアリングに失敗しました: {error}",
209
- "score_error": "❌ スコア計算エラー: {error}"
 
 
 
 
 
210
  }
211
  }
 
148
  "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
149
  "score_sensitivity_label": "品質スコア感度",
150
  "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
 
 
151
  "think_label": "思考",
152
  "parallel_thinking_label": "並列思考",
153
  "generate_btn": "🎵 音楽を生成",
 
160
  "send_to_src_btn": "🔗 ソースオーディオに送信",
161
  "save_btn": "💾 保存",
162
  "score_btn": "📊 スコア",
163
+ "lrc_btn": "🎵 LRC",
164
  "quality_score_label": "品質スコア(サンプル {n})",
165
  "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
166
+ "lrc_label": "歌詞タイムスタンプ(サンプル {n})",
167
+ "lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
168
+ "details_accordion": "📊 スコア & LRC",
169
  "generation_status": "生成ステータス",
170
  "current_batch": "現在のバッチ",
171
  "batch_indicator": "バッチ {current} / {total}",
 
175
  "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
176
  "batch_results_title": "📁 バッチ結果と生成詳細",
177
  "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
178
+ "generation_details": "生成詳細"
 
 
 
 
179
  },
180
  "messages": {
181
  "no_audio_to_save": "❌ 保存するオーディオがありません",
 
204
  "scoring_failed": "❌ エラー: バッチデータが見つかりません",
205
  "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
206
  "score_failed": "❌ スコアリングに失敗しました: {error}",
207
+ "score_error": "❌ スコア計算エラー: {error}",
208
+ "lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
209
+ "lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
210
+ "lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
211
+ "lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
212
+ "lrc_empty_result": "⚠️ LRC生成の結果が空です。"
213
  }
214
  }
acestep/gradio_ui/i18n/zh.json CHANGED
@@ -148,8 +148,6 @@
148
  "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
149
  "score_sensitivity_label": "质量评分敏感度",
150
  "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
151
- "attention_focus_label": "输出注意力焦点分数(已禁用)",
152
- "attention_focus_info": "输出注意力焦点分数分析",
153
  "think_label": "思考",
154
  "parallel_thinking_label": "并行思考",
155
  "generate_btn": "🎵 生成音乐",
@@ -162,8 +160,12 @@
162
  "send_to_src_btn": "🔗 发送到源音频",
163
  "save_btn": "💾 保存",
164
  "score_btn": "📊 评分",
 
165
  "quality_score_label": "质量分数(样本 {n})",
166
  "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
 
 
 
167
  "generation_status": "生成状态",
168
  "current_batch": "当前批次",
169
  "batch_indicator": "批次 {current} / {total}",
@@ -173,11 +175,7 @@
173
  "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
174
  "batch_results_title": "📁 批量结果和生成详情",
175
  "all_files_label": "📁 所有生成的文件(下载)",
176
- "generation_details": "生成详情",
177
- "attention_analysis": "⚖️ 注意力焦点分数分析",
178
- "attention_score": "注意力焦点分数(样本 {n})",
179
- "lyric_timestamps": "歌词时间戳(样本 {n})",
180
- "attention_heatmap": "注意力焦点分数热图(样本 {n})"
181
  },
182
  "messages": {
183
  "no_audio_to_save": "❌ 没有要保存的音频",
@@ -206,6 +204,11 @@
206
  "scoring_failed": "❌ 错误: 未找到批次数据",
207
  "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
208
  "score_failed": "❌ 评分失败: {error}",
209
- "score_error": "❌ 计算分数时出错: {error}"
 
 
 
 
 
210
  }
211
  }
 
148
  "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
149
  "score_sensitivity_label": "质量评分敏感度",
150
  "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
 
 
151
  "think_label": "思考",
152
  "parallel_thinking_label": "并行思考",
153
  "generate_btn": "🎵 生成音乐",
 
160
  "send_to_src_btn": "🔗 发送到源音频",
161
  "save_btn": "💾 保存",
162
  "score_btn": "📊 评分",
163
+ "lrc_btn": "🎵 LRC",
164
  "quality_score_label": "质量分数(样本 {n})",
165
  "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
166
+ "lrc_label": "歌词时间戳(样本 {n})",
167
+ "lrc_placeholder": "点击'LRC'生成时间戳",
168
+ "details_accordion": "📊 评分与LRC",
169
  "generation_status": "生成状态",
170
  "current_batch": "当前批次",
171
  "batch_indicator": "批次 {current} / {total}",
 
175
  "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
176
  "batch_results_title": "📁 批量结果和生成详情",
177
  "all_files_label": "📁 所有生成的文件(下载)",
178
+ "generation_details": "生成详情"
 
 
 
 
179
  },
180
  "messages": {
181
  "no_audio_to_save": "❌ 没有要保存的音频",
 
204
  "scoring_failed": "❌ 错误: 未找到批次数据",
205
  "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
206
  "score_failed": "❌ 评分失败: {error}",
207
+ "score_error": "❌ 计算分数时出错: {error}",
208
+ "lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
209
+ "lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
210
+ "lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
211
+ "lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
212
+ "lrc_empty_result": "⚠️ LRC生成结果为空。"
213
  }
214
  }
acestep/gradio_ui/interfaces/result.py CHANGED
@@ -50,11 +50,24 @@ def create_results_section(dit_handler) -> dict:
50
  size="sm",
51
  scale=1
52
  )
53
- score_display_1 = gr.Textbox(
54
- label=t("results.quality_score_label", n=1),
55
- interactive=False,
56
- placeholder=t("results.quality_score_placeholder")
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  with gr.Column(visible=True) as audio_col_2:
59
  generated_audio_2 = gr.Audio(
60
  label=t("results.generated_music", n=2),
@@ -81,11 +94,24 @@ def create_results_section(dit_handler) -> dict:
81
  size="sm",
82
  scale=1
83
  )
84
- score_display_2 = gr.Textbox(
85
- label=t("results.quality_score_label", n=2),
86
- interactive=False,
87
- placeholder=t("results.quality_score_placeholder")
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  with gr.Column(visible=False) as audio_col_3:
90
  generated_audio_3 = gr.Audio(
91
  label=t("results.generated_music", n=3),
@@ -112,11 +138,24 @@ def create_results_section(dit_handler) -> dict:
112
  size="sm",
113
  scale=1
114
  )
115
- score_display_3 = gr.Textbox(
116
- label=t("results.quality_score_label", n=3),
117
- interactive=False,
118
- placeholder=t("results.quality_score_placeholder")
119
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  with gr.Column(visible=False) as audio_col_4:
121
  generated_audio_4 = gr.Audio(
122
  label=t("results.generated_music", n=4),
@@ -143,11 +182,24 @@ def create_results_section(dit_handler) -> dict:
143
  size="sm",
144
  scale=1
145
  )
146
- score_display_4 = gr.Textbox(
147
- label=t("results.quality_score_label", n=4),
148
- interactive=False,
149
- placeholder=t("results.quality_score_placeholder")
150
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  # Second row for batch size 5-8 (initially hidden)
153
  with gr.Row(visible=False) as audio_row_5_8:
@@ -162,11 +214,19 @@ def create_results_section(dit_handler) -> dict:
162
  send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
163
  save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
164
  score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
165
- score_display_5 = gr.Textbox(
166
- label=t("results.quality_score_label", n=5),
167
- interactive=False,
168
- placeholder=t("results.quality_score_placeholder")
169
- )
 
 
 
 
 
 
 
 
170
  with gr.Column() as audio_col_6:
171
  generated_audio_6 = gr.Audio(
172
  label=t("results.generated_music", n=6),
@@ -178,11 +238,19 @@ def create_results_section(dit_handler) -> dict:
178
  send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
179
  save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
180
  score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
181
- score_display_6 = gr.Textbox(
182
- label=t("results.quality_score_label", n=6),
183
- interactive=False,
184
- placeholder=t("results.quality_score_placeholder")
185
- )
 
 
 
 
 
 
 
 
186
  with gr.Column() as audio_col_7:
187
  generated_audio_7 = gr.Audio(
188
  label=t("results.generated_music", n=7),
@@ -194,11 +262,19 @@ def create_results_section(dit_handler) -> dict:
194
  send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
195
  save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
196
  score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
197
- score_display_7 = gr.Textbox(
198
- label=t("results.quality_score_label", n=7),
199
- interactive=False,
200
- placeholder=t("results.quality_score_placeholder")
201
- )
 
 
 
 
 
 
 
 
202
  with gr.Column() as audio_col_8:
203
  generated_audio_8 = gr.Audio(
204
  label=t("results.generated_music", n=8),
@@ -210,11 +286,19 @@ def create_results_section(dit_handler) -> dict:
210
  send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
211
  save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
212
  score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
213
- score_display_8 = gr.Textbox(
214
- label=t("results.quality_score_label", n=8),
215
- interactive=False,
216
- placeholder=t("results.quality_score_placeholder")
217
- )
 
 
 
 
 
 
 
 
218
 
219
  status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
220
 
@@ -262,17 +346,6 @@ def create_results_section(dit_handler) -> dict:
262
  interactive=False
263
  )
264
  generation_info = gr.Markdown(label=t("results.generation_details"))
265
-
266
- with gr.Accordion(t("results.attention_analysis"), open=False):
267
- with gr.Row():
268
- with gr.Column():
269
- align_score_1 = gr.Textbox(label=t("results.attention_score", n=1), interactive=False)
270
- align_text_1 = gr.Textbox(label=t("results.lyric_timestamps", n=1), interactive=False, lines=10)
271
- align_plot_1 = gr.Plot(label=t("results.attention_heatmap", n=1))
272
- with gr.Column():
273
- align_score_2 = gr.Textbox(label=t("results.attention_score", n=2), interactive=False)
274
- align_text_2 = gr.Textbox(label=t("results.lyric_timestamps", n=2), interactive=False, lines=10)
275
- align_plot_2 = gr.Plot(label=t("results.attention_heatmap", n=2))
276
 
277
  return {
278
  "lm_metadata_state": lm_metadata_state,
@@ -337,13 +410,31 @@ def create_results_section(dit_handler) -> dict:
337
  "score_display_6": score_display_6,
338
  "score_display_7": score_display_7,
339
  "score_display_8": score_display_8,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  "generated_audio_batch": generated_audio_batch,
341
  "generation_info": generation_info,
342
- "align_score_1": align_score_1,
343
- "align_text_1": align_text_1,
344
- "align_plot_1": align_plot_1,
345
- "align_score_2": align_score_2,
346
- "align_text_2": align_text_2,
347
- "align_plot_2": align_plot_2,
348
  }
349
 
 
50
  size="sm",
51
  scale=1
52
  )
53
+ lrc_btn_1 = gr.Button(
54
+ t("results.lrc_btn"),
55
+ variant="secondary",
56
+ size="sm",
57
+ scale=1
58
+ )
59
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_1:
60
+ score_display_1 = gr.Textbox(
61
+ label=t("results.quality_score_label", n=1),
62
+ interactive=False,
63
+ visible=False
64
+ )
65
+ lrc_display_1 = gr.Textbox(
66
+ label=t("results.lrc_label", n=1),
67
+ interactive=False,
68
+ lines=8,
69
+ visible=False
70
+ )
71
  with gr.Column(visible=True) as audio_col_2:
72
  generated_audio_2 = gr.Audio(
73
  label=t("results.generated_music", n=2),
 
94
  size="sm",
95
  scale=1
96
  )
97
+ lrc_btn_2 = gr.Button(
98
+ t("results.lrc_btn"),
99
+ variant="secondary",
100
+ size="sm",
101
+ scale=1
102
+ )
103
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_2:
104
+ score_display_2 = gr.Textbox(
105
+ label=t("results.quality_score_label", n=2),
106
+ interactive=False,
107
+ visible=False
108
+ )
109
+ lrc_display_2 = gr.Textbox(
110
+ label=t("results.lrc_label", n=2),
111
+ interactive=False,
112
+ lines=8,
113
+ visible=False
114
+ )
115
  with gr.Column(visible=False) as audio_col_3:
116
  generated_audio_3 = gr.Audio(
117
  label=t("results.generated_music", n=3),
 
138
  size="sm",
139
  scale=1
140
  )
141
+ lrc_btn_3 = gr.Button(
142
+ t("results.lrc_btn"),
143
+ variant="secondary",
144
+ size="sm",
145
+ scale=1
146
+ )
147
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_3:
148
+ score_display_3 = gr.Textbox(
149
+ label=t("results.quality_score_label", n=3),
150
+ interactive=False,
151
+ visible=False
152
+ )
153
+ lrc_display_3 = gr.Textbox(
154
+ label=t("results.lrc_label", n=3),
155
+ interactive=False,
156
+ lines=8,
157
+ visible=False
158
+ )
159
  with gr.Column(visible=False) as audio_col_4:
160
  generated_audio_4 = gr.Audio(
161
  label=t("results.generated_music", n=4),
 
182
  size="sm",
183
  scale=1
184
  )
185
+ lrc_btn_4 = gr.Button(
186
+ t("results.lrc_btn"),
187
+ variant="secondary",
188
+ size="sm",
189
+ scale=1
190
+ )
191
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_4:
192
+ score_display_4 = gr.Textbox(
193
+ label=t("results.quality_score_label", n=4),
194
+ interactive=False,
195
+ visible=False
196
+ )
197
+ lrc_display_4 = gr.Textbox(
198
+ label=t("results.lrc_label", n=4),
199
+ interactive=False,
200
+ lines=8,
201
+ visible=False
202
+ )
203
 
204
  # Second row for batch size 5-8 (initially hidden)
205
  with gr.Row(visible=False) as audio_row_5_8:
 
214
  send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
215
  save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
216
  score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
217
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
218
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_5:
219
+ score_display_5 = gr.Textbox(
220
+ label=t("results.quality_score_label", n=5),
221
+ interactive=False,
222
+ visible=False
223
+ )
224
+ lrc_display_5 = gr.Textbox(
225
+ label=t("results.lrc_label", n=5),
226
+ interactive=False,
227
+ lines=8,
228
+ visible=False
229
+ )
230
  with gr.Column() as audio_col_6:
231
  generated_audio_6 = gr.Audio(
232
  label=t("results.generated_music", n=6),
 
238
  send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
239
  save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
240
  score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
241
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
242
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_6:
243
+ score_display_6 = gr.Textbox(
244
+ label=t("results.quality_score_label", n=6),
245
+ interactive=False,
246
+ visible=False
247
+ )
248
+ lrc_display_6 = gr.Textbox(
249
+ label=t("results.lrc_label", n=6),
250
+ interactive=False,
251
+ lines=8,
252
+ visible=False
253
+ )
254
  with gr.Column() as audio_col_7:
255
  generated_audio_7 = gr.Audio(
256
  label=t("results.generated_music", n=7),
 
262
  send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
263
  save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
264
  score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
265
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
266
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_7:
267
+ score_display_7 = gr.Textbox(
268
+ label=t("results.quality_score_label", n=7),
269
+ interactive=False,
270
+ visible=False
271
+ )
272
+ lrc_display_7 = gr.Textbox(
273
+ label=t("results.lrc_label", n=7),
274
+ interactive=False,
275
+ lines=8,
276
+ visible=False
277
+ )
278
  with gr.Column() as audio_col_8:
279
  generated_audio_8 = gr.Audio(
280
  label=t("results.generated_music", n=8),
 
286
  send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
287
  save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
288
  score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
289
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
290
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=False) as details_accordion_8:
291
+ score_display_8 = gr.Textbox(
292
+ label=t("results.quality_score_label", n=8),
293
+ interactive=False,
294
+ visible=False
295
+ )
296
+ lrc_display_8 = gr.Textbox(
297
+ label=t("results.lrc_label", n=8),
298
+ interactive=False,
299
+ lines=8,
300
+ visible=False
301
+ )
302
 
303
  status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
304
 
 
346
  interactive=False
347
  )
348
  generation_info = gr.Markdown(label=t("results.generation_details"))
 
 
 
 
 
 
 
 
 
 
 
349
 
350
  return {
351
  "lm_metadata_state": lm_metadata_state,
 
410
  "score_display_6": score_display_6,
411
  "score_display_7": score_display_7,
412
  "score_display_8": score_display_8,
413
+ "lrc_btn_1": lrc_btn_1,
414
+ "lrc_btn_2": lrc_btn_2,
415
+ "lrc_btn_3": lrc_btn_3,
416
+ "lrc_btn_4": lrc_btn_4,
417
+ "lrc_btn_5": lrc_btn_5,
418
+ "lrc_btn_6": lrc_btn_6,
419
+ "lrc_btn_7": lrc_btn_7,
420
+ "lrc_btn_8": lrc_btn_8,
421
+ "lrc_display_1": lrc_display_1,
422
+ "lrc_display_2": lrc_display_2,
423
+ "lrc_display_3": lrc_display_3,
424
+ "lrc_display_4": lrc_display_4,
425
+ "lrc_display_5": lrc_display_5,
426
+ "lrc_display_6": lrc_display_6,
427
+ "lrc_display_7": lrc_display_7,
428
+ "lrc_display_8": lrc_display_8,
429
+ "details_accordion_1": details_accordion_1,
430
+ "details_accordion_2": details_accordion_2,
431
+ "details_accordion_3": details_accordion_3,
432
+ "details_accordion_4": details_accordion_4,
433
+ "details_accordion_5": details_accordion_5,
434
+ "details_accordion_6": details_accordion_6,
435
+ "details_accordion_7": details_accordion_7,
436
+ "details_accordion_8": details_accordion_8,
437
  "generated_audio_batch": generated_audio_batch,
438
  "generation_info": generation_info,
 
 
 
 
 
 
439
  }
440
 
acestep/handler.py CHANGED
@@ -31,6 +31,7 @@ from acestep.constants import (
31
  SFT_GEN_PROMPT,
32
  DEFAULT_DIT_INSTRUCTION,
33
  )
 
34
 
35
 
36
  warnings.filterwarnings("ignore")
@@ -65,13 +66,7 @@ class AceStepHandler:
65
  self.batch_size = 2
66
 
67
  # Custom layers config
68
- self.custom_layers_config = {
69
- 2: [6, 7],
70
- 3: [10, 11],
71
- 4: [3],
72
- 5: [8, 9, 11],
73
- 6: [8]
74
- }
75
  self.offload_to_cpu = False
76
  self.offload_dit_to_cpu = False
77
  self.current_offload_cost = 0.0
@@ -1953,6 +1948,23 @@ class AceStepHandler:
1953
  }
1954
  logger.info("[service_generate] Generating audio...")
1955
  with self._load_model_context("model"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1956
  outputs = self.model.generate_audio(**generate_kwargs)
1957
 
1958
  # Add intermediate information to outputs for extra_outputs
@@ -1962,6 +1974,12 @@ class AceStepHandler:
1962
  outputs["spans"] = spans
1963
  outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
1964
 
 
 
 
 
 
 
1965
  return outputs
1966
 
1967
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
@@ -2268,16 +2286,27 @@ class AceStepHandler:
2268
  spans = outputs.get("spans", []) # List of tuples
2269
  latent_masks = outputs.get("latent_masks") # [batch, T]
2270
 
2271
- # Move latents to CPU to save memory (they can be large)
 
 
 
 
 
 
2272
  extra_outputs = {
2273
- "pred_latents": pred_latents.cpu() if pred_latents is not None else None,
2274
- "target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
2275
- "src_latents": src_latents.cpu() if src_latents is not None else None,
2276
- "chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
2277
- "latent_masks": latent_masks.cpu() if latent_masks is not None else None,
2278
  "spans": spans,
2279
  "time_costs": time_costs,
2280
  "seed_value": seed_value_for_ui,
 
 
 
 
 
2281
  }
2282
 
2283
  # Build audios list with tensor data (no file paths, no UUIDs, handled outside)
@@ -2307,3 +2336,220 @@ class AceStepHandler:
2307
  "success": False,
2308
  "error": str(e),
2309
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  SFT_GEN_PROMPT,
32
  DEFAULT_DIT_INSTRUCTION,
33
  )
34
+ from acestep.dit_alignment_score import MusicStampsAligner
35
 
36
 
37
  warnings.filterwarnings("ignore")
 
66
  self.batch_size = 2
67
 
68
  # Custom layers config
69
+ self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]}
 
 
 
 
 
 
70
  self.offload_to_cpu = False
71
  self.offload_dit_to_cpu = False
72
  self.current_offload_cost = 0.0
 
1948
  }
1949
  logger.info("[service_generate] Generating audio...")
1950
  with self._load_model_context("model"):
1951
+ # Prepare condition tensors first (for LRC timestamp generation)
1952
+ encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition(
1953
+ text_hidden_states=text_hidden_states,
1954
+ text_attention_mask=text_attention_mask,
1955
+ lyric_hidden_states=lyric_hidden_states,
1956
+ lyric_attention_mask=lyric_attention_mask,
1957
+ refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
1958
+ refer_audio_order_mask=refer_audio_order_mask,
1959
+ hidden_states=src_latents,
1960
+ attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype),
1961
+ silence_latent=self.silence_latent,
1962
+ src_latents=src_latents,
1963
+ chunk_masks=chunk_mask,
1964
+ is_covers=is_covers,
1965
+ precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz,
1966
+ )
1967
+
1968
  outputs = self.model.generate_audio(**generate_kwargs)
1969
 
1970
  # Add intermediate information to outputs for extra_outputs
 
1974
  outputs["spans"] = spans
1975
  outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
1976
 
1977
+ # Add condition tensors for LRC timestamp generation
1978
+ outputs["encoder_hidden_states"] = encoder_hidden_states
1979
+ outputs["encoder_attention_mask"] = encoder_attention_mask
1980
+ outputs["context_latents"] = context_latents
1981
+ outputs["lyric_token_idss"] = lyric_token_idss
1982
+
1983
  return outputs
1984
 
1985
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
 
2286
  spans = outputs.get("spans", []) # List of tuples
2287
  latent_masks = outputs.get("latent_masks") # [batch, T]
2288
 
2289
+ # Extract condition tensors for LRC timestamp generation
2290
+ encoder_hidden_states = outputs.get("encoder_hidden_states")
2291
+ encoder_attention_mask = outputs.get("encoder_attention_mask")
2292
+ context_latents = outputs.get("context_latents")
2293
+ lyric_token_idss = outputs.get("lyric_token_idss")
2294
+
2295
+ # Move all tensors to CPU to save VRAM (detach to release computation graph)
2296
  extra_outputs = {
2297
+ "pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None,
2298
+ "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None,
2299
+ "src_latents": src_latents.detach().cpu() if src_latents is not None else None,
2300
+ "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None,
2301
+ "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None,
2302
  "spans": spans,
2303
  "time_costs": time_costs,
2304
  "seed_value": seed_value_for_ui,
2305
+ # Condition tensors for LRC timestamp generation
2306
+ "encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None,
2307
+ "encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None,
2308
+ "context_latents": context_latents.detach().cpu() if context_latents is not None else None,
2309
+ "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None,
2310
  }
2311
 
2312
  # Build audios list with tensor data (no file paths, no UUIDs, handled outside)
 
2336
  "success": False,
2337
  "error": str(e),
2338
  }
2339
+
2340
+ @torch.no_grad()
2341
+ def get_lyric_timestamp(
2342
+ self,
2343
+ pred_latent: torch.Tensor,
2344
+ encoder_hidden_states: torch.Tensor,
2345
+ encoder_attention_mask: torch.Tensor,
2346
+ context_latents: torch.Tensor,
2347
+ lyric_token_ids: torch.Tensor,
2348
+ total_duration_seconds: float,
2349
+ vocal_language: str = "en",
2350
+ inference_steps: int = 8,
2351
+ seed: int = 42,
2352
+ custom_layers_config: Optional[Dict] = None,
2353
+ ) -> Dict[str, Any]:
2354
+ """
2355
+ Generate lyrics timestamps from generated audio latents using cross-attention alignment.
2356
+
2357
+ This method adds noise to the final pred_latent and re-infers one step to get
2358
+ cross-attention matrices, then uses DTW to align lyrics tokens with audio frames.
2359
+
2360
+ Args:
2361
+ pred_latent: Generated latent tensor [batch, T, D]
2362
+ encoder_hidden_states: Cached encoder hidden states
2363
+ encoder_attention_mask: Cached encoder attention mask
2364
+ context_latents: Cached context latents
2365
+ lyric_token_ids: Tokenized lyrics tensor [batch, seq_len]
2366
+ total_duration_seconds: Total audio duration in seconds
2367
+ vocal_language: Language code for lyrics header parsing
2368
+ inference_steps: Number of inference steps (for noise level calculation)
2369
+ seed: Random seed for noise generation
2370
+ custom_layers_config: Dict mapping layer indices to head indices
2371
+
2372
+ Returns:
2373
+ Dict containing:
2374
+ - lrc_text: LRC formatted lyrics with timestamps
2375
+ - sentence_timestamps: List of SentenceTimestamp objects
2376
+ - token_timestamps: List of TokenTimestamp objects
2377
+ - success: Whether generation succeeded
2378
+ - error: Error message if failed
2379
+ """
2380
+ from transformers.cache_utils import EncoderDecoderCache, DynamicCache
2381
+
2382
+ if self.model is None:
2383
+ return {
2384
+ "lrc_text": "",
2385
+ "sentence_timestamps": [],
2386
+ "token_timestamps": [],
2387
+ "success": False,
2388
+ "error": "Model not initialized"
2389
+ }
2390
+
2391
+ if custom_layers_config is None:
2392
+ custom_layers_config = self.custom_layers_config
2393
+
2394
+ try:
2395
+ # Move tensors to device
2396
+ device = self.device
2397
+ dtype = self.dtype
2398
+
2399
+ pred_latent = pred_latent.to(device=device, dtype=dtype)
2400
+ encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype)
2401
+ encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype)
2402
+ context_latents = context_latents.to(device=device, dtype=dtype)
2403
+
2404
+ bsz = pred_latent.shape[0]
2405
+
2406
+ # Calculate noise level: t_last = 1.0 / inference_steps
2407
+ t_last_val = 1.0 / inference_steps
2408
+ t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype)
2409
+
2410
+ x1 = pred_latent
2411
+
2412
+ # Generate noise
2413
+ if seed is None:
2414
+ x0 = torch.randn_like(x1)
2415
+ else:
2416
+ generator = torch.Generator(device=device).manual_seed(int(seed))
2417
+ x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype)
2418
+
2419
+ # Add noise to pred_latent: xt = t * noise + (1 - t) * x1
2420
+ xt = t_last_val * x0 + (1.0 - t_last_val) * x1
2421
+
2422
+ xt_in = xt
2423
+ t_in = t_curr_tensor
2424
+
2425
+ # Get null condition embedding
2426
+ encoder_hidden_states_in = encoder_hidden_states
2427
+ encoder_attention_mask_in = encoder_attention_mask
2428
+ context_latents_in = context_latents
2429
+ latent_length = x1.shape[1]
2430
+ attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype)
2431
+ attention_mask_in = attention_mask
2432
+ past_key_values = None
2433
+
2434
+ # Run decoder with output_attentions=True
2435
+ with self._load_model_context("model"):
2436
+ decoder = self.model.decoder
2437
+ decoder_outputs = decoder(
2438
+ hidden_states=xt_in,
2439
+ timestep=t_in,
2440
+ timestep_r=t_in,
2441
+ attention_mask=attention_mask_in,
2442
+ encoder_hidden_states=encoder_hidden_states_in,
2443
+ use_cache=False,
2444
+ past_key_values=past_key_values,
2445
+ encoder_attention_mask=encoder_attention_mask_in,
2446
+ context_latents=context_latents_in,
2447
+ output_attentions=True,
2448
+ custom_layers_config=custom_layers_config,
2449
+ enable_early_exit=True
2450
+ )
2451
+
2452
+ # Extract cross-attention matrices
2453
+ if decoder_outputs[2] is None:
2454
+ return {
2455
+ "lrc_text": "",
2456
+ "sentence_timestamps": [],
2457
+ "token_timestamps": [],
2458
+ "success": False,
2459
+ "error": "Model did not return attentions"
2460
+ }
2461
+
2462
+ cross_attns = decoder_outputs[2] # Tuple of tensors (some may be None)
2463
+
2464
+ captured_layers_list = []
2465
+ for layer_attn in cross_attns:
2466
+ # Skip None values (layers that didn't return attention)
2467
+ if layer_attn is None:
2468
+ continue
2469
+ # Only take conditional part (first half of batch)
2470
+ cond_attn = layer_attn[:bsz]
2471
+ layer_matrix = cond_attn.transpose(-1, -2)
2472
+ captured_layers_list.append(layer_matrix)
2473
+
2474
+ if not captured_layers_list:
2475
+ return {
2476
+ "lrc_text": "",
2477
+ "sentence_timestamps": [],
2478
+ "token_timestamps": [],
2479
+ "success": False,
2480
+ "error": "No valid attention layers returned"
2481
+ }
2482
+
2483
+ stacked = torch.stack(captured_layers_list)
2484
+ if bsz == 1:
2485
+ all_layers_matrix = stacked.squeeze(1)
2486
+ else:
2487
+ all_layers_matrix = stacked
2488
+
2489
+ # Process lyric token IDs to extract pure lyrics
2490
+ if isinstance(lyric_token_ids, torch.Tensor):
2491
+ raw_lyric_ids = lyric_token_ids[0].tolist()
2492
+ else:
2493
+ raw_lyric_ids = lyric_token_ids
2494
+
2495
+ # Parse header to find lyrics start position
2496
+ header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n"
2497
+ header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False)
2498
+ start_idx = len(header_ids)
2499
+
2500
+ # Find end of lyrics (before endoftext token)
2501
+ try:
2502
+ end_idx = raw_lyric_ids.index(151643) # <|endoftext|> token
2503
+ except ValueError:
2504
+ end_idx = len(raw_lyric_ids)
2505
+
2506
+ pure_lyric_ids = raw_lyric_ids[start_idx:end_idx]
2507
+ pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :]
2508
+
2509
+ # Create aligner and generate timestamps
2510
+ aligner = MusicStampsAligner(self.text_tokenizer)
2511
+
2512
+ align_info = aligner.stamps_align_info(
2513
+ attention_matrix=pure_lyric_matrix,
2514
+ lyrics_tokens=pure_lyric_ids,
2515
+ total_duration_seconds=total_duration_seconds,
2516
+ custom_config=custom_layers_config,
2517
+ return_matrices=False,
2518
+ violence_level=2.0,
2519
+ medfilt_width=1,
2520
+ )
2521
+
2522
+ if align_info.get("calc_matrix") is None:
2523
+ return {
2524
+ "lrc_text": "",
2525
+ "sentence_timestamps": [],
2526
+ "token_timestamps": [],
2527
+ "success": False,
2528
+ "error": align_info.get("error", "Failed to process attention matrix")
2529
+ }
2530
+
2531
+ # Generate timestamps
2532
+ result = aligner.get_timestamps_and_lrc(
2533
+ calc_matrix=align_info["calc_matrix"],
2534
+ lyrics_tokens=pure_lyric_ids,
2535
+ total_duration_seconds=total_duration_seconds
2536
+ )
2537
+
2538
+ return {
2539
+ "lrc_text": result["lrc_text"],
2540
+ "sentence_timestamps": result["sentence_timestamps"],
2541
+ "token_timestamps": result["token_timestamps"],
2542
+ "success": True,
2543
+ "error": None
2544
+ }
2545
+
2546
+ except Exception as e:
2547
+ error_msg = f"Error generating timestamps: {str(e)}"
2548
+ logger.exception("[get_lyric_timestamp] Failed")
2549
+ return {
2550
+ "lrc_text": "",
2551
+ "sentence_timestamps": [],
2552
+ "token_timestamps": [],
2553
+ "success": False,
2554
+ "error": error_msg
2555
+ }
pyproject.toml CHANGED
@@ -30,7 +30,7 @@ dependencies = [
30
  "uvicorn[standard]>=0.27.0",
31
 
32
  # Local third-party packages
33
- "nano-vllm @ file:///${PROJECT_ROOT}/acestep/third_parts/nano-vllm",
34
  ]
35
 
36
  [project.scripts]
@@ -41,8 +41,8 @@ acestep-api = "acestep.api_server:main"
41
  requires = ["hatchling"]
42
  build-backend = "hatchling.build"
43
 
44
- [tool.uv]
45
- dev-dependencies = []
46
 
47
  [[tool.uv.index]]
48
  name = "pytorch"
 
30
  "uvicorn[standard]>=0.27.0",
31
 
32
  # Local third-party packages
33
+ "nano-vllm @ {root:uri}/acestep/third_parts/nano-vllm",
34
  ]
35
 
36
  [project.scripts]
 
41
  requires = ["hatchling"]
42
  build-backend = "hatchling.build"
43
 
44
+ [dependency-groups]
45
+ dev = []
46
 
47
  [[tool.uv.index]]
48
  name = "pytorch"