ChuxiJ committed
Commit 11860f1 · 1 Parent(s): 24f370e

add inference code and doc

acestep/api_server.py CHANGED
@@ -868,7 +868,7 @@ def create_app() -> FastAPI:
             if s in {"", "N/A"}:
                 return None
             return s
-        first, second, paths, gen_info, status_msg, seed_value, *_ = h.generate_music(
+        result = h.generate_music(
            captions=req.caption,
            lyrics=req.lyrics,
            bpm=bpm_val,
@@ -896,10 +896,20 @@ def create_app() -> FastAPI:
            use_tiled_decode=req.use_tiled_decode,
            progress=None,
        )
+
+        # Extract values from new dict structure
+        audios = result.get("audios", [])
+        audio_paths = [audio.get("path") for audio in audios]
+        first = audio_paths[0] if len(audio_paths) > 0 else None
+        second = audio_paths[1] if len(audio_paths) > 1 else None
+        gen_info = result.get("generation_info", "")
+        status_msg = result.get("status_message", "")
+        seed_value = result.get("extra_outputs", {}).get("seed_value", "")
+
        return {
            "first_audio_path": _path_to_audio_url(first) if first else None,
            "second_audio_path": _path_to_audio_url(second) if second else None,
-            "audio_paths": [_path_to_audio_url(p) for p in (paths or [])],
+            "audio_paths": [_path_to_audio_url(p) for p in (audio_paths or [])],
            "generation_info": gen_info,
            "status_message": status_msg,
            "seed_value": seed_value,
acestep/audio_utils.py ADDED
@@ -0,0 +1,396 @@
+"""
+Audio saving and transcoding utility module
+
+Independent audio file operations outside of handler, supporting:
+- Save audio tensor/numpy to files (default FLAC format, fast)
+- Format conversion (FLAC/WAV/MP3)
+- Batch processing
+"""
+
+import os
+import hashlib
+import json
+from pathlib import Path
+from typing import Union, Optional, List, Tuple
+import torch
+import numpy as np
+import torchaudio
+from loguru import logger
+
+
+class AudioSaver:
+    """Audio saving and transcoding utility class"""
+
+    def __init__(self, default_format: str = "flac"):
+        """
+        Initialize audio saver
+
+        Args:
+            default_format: Default save format ('flac', 'wav', 'mp3')
+        """
+        self.default_format = default_format.lower()
+        if self.default_format not in ["flac", "wav", "mp3"]:
+            logger.warning(f"Unsupported format {default_format}, using 'flac'")
+            self.default_format = "flac"
+
+    def save_audio(
+        self,
+        audio_data: Union[torch.Tensor, np.ndarray],
+        output_path: Union[str, Path],
+        sample_rate: int = 48000,
+        format: Optional[str] = None,
+        channels_first: bool = True,
+    ) -> str:
+        """
+        Save audio data to file
+
+        Args:
+            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
+            output_path: Output file path (extension can be omitted)
+            sample_rate: Sample rate
+            format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
+            channels_first: If True, tensor format is [channels, samples], else [samples, channels]
+
+        Returns:
+            Actual saved file path
+        """
+        format = (format or self.default_format).lower()
+        if format not in ["flac", "wav", "mp3"]:
+            logger.warning(f"Unsupported format {format}, using {self.default_format}")
+            format = self.default_format
+
+        # Ensure output path has correct extension
+        output_path = Path(output_path)
+        if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
+            output_path = output_path.with_suffix(f'.{format}')
+
+        # Convert to torch tensor
+        if isinstance(audio_data, np.ndarray):
+            if channels_first:
+                # numpy [samples, channels] -> tensor [channels, samples]
+                audio_tensor = torch.from_numpy(audio_data.T).float()
+            else:
+                # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
+                audio_tensor = torch.from_numpy(audio_data).float()
+                if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
+                    audio_tensor = audio_tensor.T
+        else:
+            # torch tensor
+            audio_tensor = audio_data.cpu().float()
+            if not channels_first and audio_tensor.dim() == 2:
+                # [samples, channels] -> [channels, samples]
+                if audio_tensor.shape[0] > audio_tensor.shape[1]:
+                    audio_tensor = audio_tensor.T
+
+        # Ensure memory is contiguous
+        audio_tensor = audio_tensor.contiguous()
+
+        # Select backend and save
+        try:
+            if format == "mp3":
+                # MP3 uses ffmpeg backend
+                from torchaudio.io import CodecConfig
+                config = CodecConfig(bit_rate=192000, compression_level=1)
+                torchaudio.save(
+                    str(output_path),
+                    audio_tensor,
+                    sample_rate,
+                    channels_first=True,
+                    backend='ffmpeg',
+                    compression=config,
+                    buffer_size=65536
+                )
+            elif format in ["flac", "wav"]:
+                # FLAC and WAV use soundfile backend (fastest)
+                torchaudio.save(
+                    str(output_path),
+                    audio_tensor,
+                    sample_rate,
+                    channels_first=True,
+                    backend='soundfile',
+                    buffer_size=65536
+                )
+            else:
+                # Other formats use default backend
+                torchaudio.save(
+                    str(output_path),
+                    audio_tensor,
+                    sample_rate,
+                    channels_first=True,
+                    buffer_size=65536
+                )
+
+            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
+            return str(output_path)
+
+        except Exception as e:
+            logger.error(f"[AudioSaver] Failed to save audio: {e}")
+            raise
+
+    def convert_audio(
+        self,
+        input_path: Union[str, Path],
+        output_path: Union[str, Path],
+        output_format: str,
+        remove_input: bool = False,
+    ) -> str:
+        """
+        Convert audio format
+
+        Args:
+            input_path: Input audio file path
+            output_path: Output audio file path
+            output_format: Target format ('flac', 'wav', 'mp3')
+            remove_input: Whether to delete input file
+
+        Returns:
+            Output file path
+        """
+        input_path = Path(input_path)
+        output_path = Path(output_path)
+
+        if not input_path.exists():
+            raise FileNotFoundError(f"Input file not found: {input_path}")
+
+        # Load audio
+        audio_tensor, sample_rate = torchaudio.load(str(input_path))
+
+        # Save as new format
+        output_path = self.save_audio(
+            audio_tensor,
+            output_path,
+            sample_rate=sample_rate,
+            format=output_format,
+            channels_first=True
+        )
+
+        # Delete input file if needed
+        if remove_input:
+            input_path.unlink()
+            logger.debug(f"[AudioSaver] Removed input file: {input_path}")
+
+        return output_path
+
+    def save_batch(
+        self,
+        audio_batch: Union[List[torch.Tensor], torch.Tensor],
+        output_dir: Union[str, Path],
+        file_prefix: str = "audio",
+        sample_rate: int = 48000,
+        format: Optional[str] = None,
+        channels_first: bool = True,
+    ) -> List[str]:
+        """
+        Save audio batch
+
+        Args:
+            audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
+            output_dir: Output directory
+            file_prefix: File prefix
+            sample_rate: Sample rate
+            format: Audio format
+            channels_first: Tensor format flag
+
+        Returns:
+            List of saved file paths
+        """
+        output_dir = Path(output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Process batch
+        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
+            # [batch, channels, samples]
+            audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
+        elif isinstance(audio_batch, list):
+            audio_list = audio_batch
+        else:
+            audio_list = [audio_batch]
+
+        saved_paths = []
+        for i, audio in enumerate(audio_list):
+            output_path = output_dir / f"{file_prefix}_{i:04d}"
+            saved_path = self.save_audio(
+                audio,
+                output_path,
+                sample_rate=sample_rate,
+                format=format,
+                channels_first=channels_first
+            )
+            saved_paths.append(saved_path)
+
+        return saved_paths
+
+
+def get_audio_file_hash(audio_file) -> str:
+    """
+    Get hash identifier for an audio file.
+
+    Args:
+        audio_file: Path to audio file (str) or file-like object
+
+    Returns:
+        Hash string or empty string
+    """
+    if audio_file is None:
+        return ""
+
+    try:
+        if isinstance(audio_file, str):
+            if os.path.exists(audio_file):
+                with open(audio_file, 'rb') as f:
+                    return hashlib.md5(f.read()).hexdigest()
+            return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
+        elif hasattr(audio_file, 'name'):
+            return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
+        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
+    except Exception:
+        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
+
+
+def generate_uuid_from_params(
+    captions: str,
+    lyrics: str,
+    bpm: Optional[int],
+    key_scale: str,
+    time_signature: str,
+    vocal_language: str,
+    inference_steps: int,
+    guidance_scale: float,
+    seed: Union[str, float, int],
+    audio_duration: Optional[float],
+    audio_code_string: Union[str, List[str]],
+    repainting_start: float,
+    repainting_end: Optional[float],
+    instruction: str,
+    audio_cover_strength: float,
+    task_type: str,
+    use_adg: bool,
+    cfg_interval_start: float,
+    cfg_interval_end: float,
+    audio_format: str,
+    reference_audio=None,
+    src_audio=None,
+    batch_index: int = 0,
+) -> str:
+    """
+    Generate deterministic UUID from generation parameters.
+    Same parameters will always generate the same UUID.
+
+    Args:
+        captions: Music caption
+        lyrics: Lyrics text
+        bpm: BPM value
+        key_scale: Musical key and scale
+        time_signature: Time signature
+        vocal_language: Vocal language code
+        inference_steps: Number of inference steps
+        guidance_scale: Guidance scale
+        seed: Random seed
+        audio_duration: Audio duration in seconds
+        audio_code_string: Audio code string or list
+        repainting_start: Repainting start time
+        repainting_end: Repainting end time
+        instruction: Task instruction
+        audio_cover_strength: Audio cover strength
+        task_type: Task type
+        use_adg: Whether to use ADG
+        cfg_interval_start: CFG interval start
+        cfg_interval_end: CFG interval end
+        audio_format: Audio format
+        reference_audio: Reference audio file path
+        src_audio: Source audio file path
+        batch_index: Index in batch (for audio_code_string list access)
+
+    Returns:
+        UUID string
+    """
+    params_dict = {
+        "captions": captions or "",
+        "lyrics": lyrics or "",
+        "bpm": bpm,
+        "key_scale": key_scale or "",
+        "time_signature": time_signature or "",
+        "vocal_language": vocal_language or "",
+        "inference_steps": inference_steps,
+        "guidance_scale": guidance_scale,
+        "seed": seed,
+        "audio_duration": audio_duration,
+        "audio_code_string": audio_code_string if isinstance(audio_code_string, str) else (audio_code_string[batch_index] if isinstance(audio_code_string, list) and batch_index < len(audio_code_string) else ""),
+        "repainting_start": repainting_start,
+        "repainting_end": repainting_end,
+        "instruction": instruction or "",
+        "audio_cover_strength": audio_cover_strength,
+        "task_type": task_type or "",
+        "use_adg": use_adg,
+        "cfg_interval_start": cfg_interval_start,
+        "cfg_interval_end": cfg_interval_end,
+        "audio_format": audio_format or "",
+        "reference_audio_hash": get_audio_file_hash(reference_audio),
+        "src_audio_hash": get_audio_file_hash(src_audio),
+    }
+
+    params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
+    hash_obj = hashlib.sha256(params_json.encode('utf-8'))
+    hash_hex = hash_obj.hexdigest()
+    uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
+    return uuid_str
+
+
+def generate_uuid_from_audio_data(
+    audio_data: Union[torch.Tensor, np.ndarray],
+    seed: Optional[int] = None
+) -> str:
+    """
+    Generate UUID from audio data (for caching/deduplication)
+
+    Args:
+        audio_data: Audio data
+        seed: Optional seed value
+
+    Returns:
+        UUID string
+    """
+    if isinstance(audio_data, torch.Tensor):
+        # Convert to numpy and calculate hash
+        audio_np = audio_data.cpu().numpy()
+    else:
+        audio_np = audio_data
+
+    # Calculate data hash
+    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
+
+    if seed is not None:
+        combined = f"{data_hash}_{seed}"
+        return hashlib.md5(combined.encode()).hexdigest()
+
+    return data_hash
+
+
+# Global default instance
+_default_saver = AudioSaver(default_format="flac")
+
+
+def save_audio(
+    audio_data: Union[torch.Tensor, np.ndarray],
+    output_path: Union[str, Path],
+    sample_rate: int = 48000,
+    format: Optional[str] = None,
+    channels_first: bool = True,
+) -> str:
+    """
+    Convenience function: save audio (using default configuration)
+
+    Args:
+        audio_data: Audio data
+        output_path: Output path
+        sample_rate: Sample rate
+        format: Format (default flac)
+        channels_first: Tensor format flag
+
+    Returns:
+        Saved file path
+    """
+    return _default_saver.save_audio(
+        audio_data, output_path, sample_rate, format, channels_first
+    )
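A brief usage sketch for the new helpers in acestep/audio_utils.py, assuming torchaudio's soundfile backend is available (and an ffmpeg backend for MP3); the output paths are placeholders, not paths the repo produces:

```python
import torch
from acestep.audio_utils import AudioSaver, save_audio, generate_uuid_from_audio_data

# 2-channel, 1-second silent signal at 48 kHz, shaped [channels, samples]
audio = torch.zeros(2, 48000)

# Module-level convenience wrapper (defaults to FLAC, appends the extension)
flac_path = save_audio(audio, "/tmp/example_audio", sample_rate=48000)

# Explicit saver with a different default format, plus a format conversion
saver = AudioSaver(default_format="wav")
wav_path = saver.save_audio(audio, "/tmp/example_audio_wav", sample_rate=48000)
mp3_path = saver.convert_audio(wav_path, "/tmp/example_audio_mp3", output_format="mp3")

# Deterministic hash-based ID for caching/deduplication
uid = generate_uuid_from_audio_data(audio, seed=42)
print(flac_path, wav_path, mp3_path, uid)
```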
acestep/constrained_logits_processor.py CHANGED
@@ -571,6 +571,33 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         if self.debug:
             logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens")
 
+    def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None:
+        """
+        Apply whitelist constraint inplace: only allow specified tokens, block all others.
+
+        This is more efficient than creating a mask tensor because:
+        1. No memory allocation for mask
+        2. No tensor addition operation
+
+        Args:
+            scores: [1, vocab_size] scores tensor to modify inplace
+            allowed_tokens: List of token IDs to allow (all others will be set to -inf)
+        """
+        if not allowed_tokens:
+            # No tokens allowed, set all to -inf
+            scores.fill_(float('-inf'))
+            return
+
+        # Save the original values of allowed tokens
+        allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
+        saved_values = scores[0, allowed_indices].clone()
+
+        # Set all scores to -inf
+        scores.fill_(float('-inf'))
+
+        # Restore allowed token values
+        scores[0, allowed_indices] = saved_values
+
     def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
         """
         Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.
@@ -1484,10 +1511,10 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 if self.debug:
                     logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
             else:
-                # Force EOS token when target codes count is reached
-                mask = torch.full_like(scores, float('-inf'))
-                mask[:, self.eos_token_id] = 0
-                scores = scores + mask
+                # Force EOS token when target codes count is reached - inplace
+                eos_scores = scores[:, self.eos_token_id].clone()
+                scores.fill_(float('-inf'))
+                scores[:, self.eos_token_id] = eos_scores
                 if self.debug:
                     logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
         return self._apply_temperature_scaling(scores)
@@ -1609,20 +1636,15 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
         input_ids: torch.LongTensor,
         scores: torch.FloatTensor,
     ) -> torch.FloatTensor:
-        """Process a single sequence and return modified scores."""
+        """Process a single sequence and return modified scores (inplace when possible)."""
 
         # Check if we have tokens in queue for user-provided field
         # If so, inject the next token directly
         if self.user_field_token_queue:
-            mask = torch.full_like(scores, float('-inf'))
             next_token = self.user_field_token_queue[0]
-            mask[0, next_token] = 0
-            scores = scores + mask
+            self._apply_whitelist_inplace(scores, [next_token])
            return scores
 
-        # Create mask (all -inf initially)
-        mask = torch.full_like(scores, float('-inf'))
-
         if self.state in self.fixed_strings:
             # Fixed string state: force specific tokens
             fixed_str = self.fixed_strings[self.state]
@@ -1633,28 +1655,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 # This happens when we're about to complete the </think> tag
                 if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                     # Check if the next token would complete the fixed string
-                    # We check if position_in_state + length of next token would complete it
-                    # Since we don't know which token will be selected, we check if we're close to completion
-                    # Actually, a better approach: check if this is the last character(s) of the fixed string
                     remaining_chars = len(fixed_str) - self.position_in_state
                     # If remaining is small (<= 10 chars, which is typically 1-2 tokens), force EOS
                     if remaining_chars <= 10:
                         # Force EOS token to stop generation
                         if self.eos_token_id is not None:
-                            mask[0, self.eos_token_id] = 0
-                            scores = scores + mask
+                            self._apply_whitelist_inplace(scores, [self.eos_token_id])
                             if self.debug:
                                 logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
                             return scores
 
-                for t in allowed:
-                    mask[0, t] = 0
-                # Apply mask
-                scores = scores + mask
-
-                # Update position tracking
-                # We need to check if the selected token completes the fixed string
-                # This will be done in update_state() after token selection
+                # Apply whitelist constraint inplace
+                self._apply_whitelist_inplace(scores, allowed)
             else:
                 # Position exceeds string, move to next state
                 # If stop_at_reasoning is True and we're transitioning from THINK_END_TAG,
@@ -1662,8 +1674,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                     # Force EOS token to stop generation
                     if self.eos_token_id is not None:
-                        mask[0, self.eos_token_id] = 0
-                        scores = scores + mask
+                        self._apply_whitelist_inplace(scores, [self.eos_token_id])
                         if self.debug:
                             logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
                         return scores
@@ -1676,7 +1687,9 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                     if self.debug:
                         logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
                     return scores
-                return self._process_single_sequence(input_ids, torch.zeros_like(scores))
+                # For recursion, reset scores to zero (no constraints from previous state)
+                scores.zero_()
+                return self._process_single_sequence(input_ids, scores)
 
         elif self.state == FSMState.BPM_VALUE:
             # Check if field is user-provided and we haven't started injecting yet
@@ -1690,22 +1703,18 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "bpm"
                 # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores
 
             # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "120")
             allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)
-            for t in allowed:
-                mask[0, t] = 0
 
             # Also allow newline if current token sequence prefix allows it
-            # Check if current token sequence is in prefix tree and allows newline
             token_prefix = tuple(self.accumulated_token_ids)
             if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
-                mask[0, self.newline_token] = 0
+                allowed = allowed + [self.newline_token]
 
-            scores = scores + mask
+            self._apply_whitelist_inplace(scores, allowed)
 
         elif self.state == FSMState.CAPTION_VALUE:
             # Caption field generation with YAML format support:
@@ -1724,8 +1733,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                 self.user_field_token_queue = value_tokens
                 self.current_user_field = "caption"
                 # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                 return scores
 
            # Check if we should transition after a newline (non-indented line = new field)
@@ -1757,7 +1765,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                # The field name detection will happen in update_state()
                return scores
 
-            # Block backticks (code blocks)
+            # Block backticks (code blocks) - inplace
            if self.backtick_token is not None:
                scores[0, self.backtick_token] = float('-inf')
 
@@ -1773,8 +1781,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
            if self.caption_token_count >= 512:
                # Force end by only allowing newline
                if self.newline_token is not None:
-                    mask[0, self.newline_token] = 0
-                    scores = scores + mask
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
                return scores
 
            # Allow natural generation (with blocked audio codes and backticks)
@@ -1791,8 +1798,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.user_field_token_queue = value_tokens
                self.current_user_field = "duration"
                # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
 
            # If target_duration is set, force generate that exact value
@@ -1804,26 +1810,22 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                    # Force the next digit
                    next_digit = int(target_str[current_pos])
                    if next_digit in self.digit_tokens:
-                        mask[0, self.digit_tokens[next_digit]] = 0
+                        self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
                else:
                    # All digits generated, force newline
                    if self.newline_token:
-                        mask[0, self.newline_token] = 0
-
-                scores = scores + mask
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Normal duration generation with range constraint
                # Allow valid numeric tokens using prefix tree (supports multi-digit tokens like "60", "120")
                allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)
-                for t in allowed:
-                    mask[0, t] = 0
 
                # Also allow newline if current token sequence prefix allows it
                token_prefix = tuple(self.accumulated_token_ids)
                if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
-                    mask[0, self.newline_token] = 0
+                    allowed = allowed + [self.newline_token]
 
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, allowed)
 
        elif self.state == FSMState.GENRES_VALUE:
            # Check if field is user-provided and we haven't started injecting yet
@@ -1836,8 +1838,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.user_field_token_queue = value_tokens
                self.current_user_field = "genres"
                # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
 
            # Try to hot-reload genres vocab if file has changed
@@ -1848,24 +1849,20 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
 
            if allowed:
                # Use vocabulary-constrained decoding
-                for t in allowed:
-                    mask[0, t] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, allowed)
            elif self.genres_vocab:
                # Vocab is loaded but no valid continuation found
                # Force newline to end the field
                if self.newline_token:
-                    mask[0, self.newline_token] = 0
                    if self.debug:
                        logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
-                scores = scores + mask
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Fallback: no vocab loaded, use probability-based ending
                if self._should_end_text_field(scores):
                    if self.newline_token:
-                        mask[0, self.newline_token] = 0
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
                    self._transition_to_next_state()
-                scores = scores + mask
                else:
                    # Allow any token except newline if we don't have content yet
                    if not self.accumulated_value.strip():
@@ -1884,8 +1881,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.user_field_token_queue = value_tokens
                self.current_user_field = "keyscale"
                # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
 
            # Check if current token sequence is complete (allows newline)
@@ -1893,21 +1889,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
            if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
                # Complete keyscale, allow newline
                if self.newline_token:
-                    mask[0, self.newline_token] = 0
-                    scores = scores + mask
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Not complete, allow valid continuation tokens
                allowed = self._get_allowed_keyscale_tokens()
                if allowed:
-                    for t in allowed:
-                        mask[0, t] = 0
-                    scores = scores + mask
+                    self._apply_whitelist_inplace(scores, allowed)
                else:
                    # No valid tokens found - force newline to end field
                    # This handles edge cases where keyscale format is unexpected
                    if self.newline_token:
-                        mask[0, self.newline_token] = 0
-                        scores = scores + mask
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
 
        elif self.state == FSMState.LANGUAGE_VALUE:
            # Language field: Use top-1 probability language (greedy selection)
@@ -1925,8 +1917,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.user_field_token_queue = value_tokens
                self.current_user_field = "language"
                # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
 
            # If we haven't started generating language yet (empty accumulated_token_ids),
@@ -1938,19 +1929,17 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                    candidate_tokens = list(self.language_prefix_tree[empty_prefix])
 
                    if candidate_tokens:
-                        # Find the token with highest probability (top-1)
-                        # Create a mask that blocks all tokens except candidates
-                        temp_mask = torch.full_like(scores, float('-inf'))
-                        for t in candidate_tokens:
-                            temp_mask[0, t] = 0
-                        temp_scores = scores + temp_mask
+                        # Find the token with highest probability (top-1) among candidates
+                        # Use tensor indexing to get scores of candidate tokens directly
+                        candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
+                        candidate_scores = scores[0, candidate_indices]
 
                        # Get the highest probability token among candidates
-                        top_token_id = torch.argmax(temp_scores[0]).item()
+                        best_idx = torch.argmax(candidate_scores).item()
+                        top_token_id = candidate_tokens[best_idx]
 
-                        # Only allow this top-1 token, block all others (including other language tokens)
-                        mask[0, top_token_id] = 0
-                        scores = scores + mask
+                        # Only allow this top-1 token, block all others
+                        self._apply_whitelist_inplace(scores, [top_token_id])
 
                        if self.debug:
                            top_token_text = self.tokenizer.decode([top_token_id])
@@ -1958,13 +1947,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                    else:
                        # No valid first tokens found - force newline
                        if self.newline_token:
-                            mask[0, self.newline_token] = 0
-                            scores = scores + mask
+                            self._apply_whitelist_inplace(scores, [self.newline_token])
                else:
                    # Empty prefix not in tree - force newline
                    if self.newline_token:
-                        mask[0, self.newline_token] = 0
-                        scores = scores + mask
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # We've started generating a language, continue with prefix tree constraints
                # Check if current token sequence is complete (allows newline)
@@ -1972,20 +1959,16 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
                    # Complete language, allow newline
                    if self.newline_token:
-                        mask[0, self.newline_token] = 0
-                        scores = scores + mask
+                        self._apply_whitelist_inplace(scores, [self.newline_token])
                else:
                    # Not complete, allow valid continuation tokens
                    allowed = self._get_allowed_language_tokens()
                    if allowed:
-                        for t in allowed:
-                            mask[0, t] = 0
-                        scores = scores + mask
+                        self._apply_whitelist_inplace(scores, allowed)
                    else:
                        # No valid tokens found - force newline to end field
                        if self.newline_token:
-                            mask[0, self.newline_token] = 0
-                            scores = scores + mask
+                            self._apply_whitelist_inplace(scores, [self.newline_token])
 
        elif self.state == FSMState.TIMESIG_VALUE:
            # Check if field is user-provided and we haven't started injecting yet
@@ -1998,8 +1981,7 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
                self.user_field_token_queue = value_tokens
                self.current_user_field = "timesignature"
                # Inject first token
-                mask[0, value_tokens[0]] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, [value_tokens[0]])
                return scores
 
            # Check if current token sequence is complete (allows newline)
@@ -2007,14 +1989,11 @@ class MetadataConstrainedLogitsProcessor(LogitsProcessor):
            if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
                # Complete value, allow newline
                if self.newline_token:
-                    mask[0, self.newline_token] = 0
-                    scores = scores + mask
+                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Not complete, allow valid continuation tokens
                allowed = self._get_allowed_timesig_tokens()
-                for t in allowed:
-                    mask[0, t] = 0
-                scores = scores + mask
+                self._apply_whitelist_inplace(scores, allowed)
 
        return scores
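The core of this change is replacing per-step `-inf` mask tensors with an inplace whitelist. A standalone sketch of the pattern, detached from the class and using a dummy vocabulary size:

```python
import torch

def apply_whitelist_inplace(scores: torch.Tensor, allowed_tokens: list) -> None:
    """Keep only the allowed tokens' logits in a [1, vocab_size] tensor; set the rest to -inf."""
    if not allowed_tokens:
        scores.fill_(float('-inf'))   # nothing allowed
        return
    idx = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long)
    saved = scores[0, idx].clone()    # remember original logits of allowed tokens
    scores.fill_(float('-inf'))       # block everything
    scores[0, idx] = saved            # restore the whitelist

scores = torch.randn(1, 10)
apply_whitelist_inplace(scores, [2, 5])
print(scores)  # only positions 2 and 5 remain finite
```

Compared with building a `torch.full_like(scores, -inf)` mask and adding it, this avoids one vocab-sized allocation and one tensor addition per decoding step.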
 
acestep/gradio_ui/event.py DELETED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -332,10 +332,9 @@ def generate_with_progress(
        logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
 
        # Generate batch
-        metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
+        metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition(
            caption=captions or "",
            lyrics=lyrics or "",
-            batch_size=chunk_size,
            infer_type="llm_dit",
            temperature=lm_temperature,
            cfg_scale=lm_cfg_scale,
@@ -347,6 +346,7 @@
            use_cot_language=use_cot_language,
            is_format_caption=is_format_caption,
            constrained_decoding_debug=constrained_decoding_debug,
+            batch_size=chunk_size,
            seeds=chunk_seeds,
        )
 
@@ -474,9 +474,26 @@
        progress=progress
    )
 
-    # Extract results
-    first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
-        align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
+    # Extract results from new dict structure
+    if not isinstance(result, dict):
+        # Fallback for old tuple format (should not happen)
+        first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
+            align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
+    else:
+        audios = result.get("audios", [])
+        all_audio_paths = [audio.get("path") for audio in audios]
+        first_audio = all_audio_paths[0] if len(all_audio_paths) > 0 else None
+        second_audio = all_audio_paths[1] if len(all_audio_paths) > 1 else None
+        generation_info = result.get("generation_info", "")
+        status_message = result.get("status_message", "")
+        seed_value_for_ui = result.get("extra_outputs", {}).get("seed_value", "")
+        # Legacy alignment fields (no longer used)
+        align_score_1 = ""
+        align_text_1 = ""
+        align_plot_1 = None
+        align_score_2 = ""
+        align_text_2 = ""
+        align_plot_2 = None
 
    # Extract LM timing from status if available and prepend to generation_info
    if status:
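Both this UI handler and the API server above now consume the same dictionary returned by generate_music(). A minimal sketch of that structure, inferred only from the keys read in these hunks ("audios", "generation_info", "status_message", "extra_outputs"); the concrete values are placeholders, not real handler output:

```python
# Hypothetical shape of the new generate_music() return value.
result = {
    "audios": [
        {"path": "/tmp/out_0000.flac"},
        {"path": "/tmp/out_0001.flac"},
    ],
    "generation_info": "2 audios generated",
    "status_message": "success",
    "extra_outputs": {"seed_value": "1234, 5678"},
}

audio_paths = [a.get("path") for a in result.get("audios", [])]
first = audio_paths[0] if audio_paths else None
second = audio_paths[1] if len(audio_paths) > 1 else None
seed_value = result.get("extra_outputs", {}).get("seed_value", "")
print(first, second, seed_value)
```
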
acestep/handler.py CHANGED
@@ -10,6 +10,8 @@ import traceback
 import re
 import random
 import uuid
 from contextlib import contextmanager
 from typing import Optional, Dict, Any, Tuple, List, Union
 
@@ -37,16 +39,12 @@ warnings.filterwarnings("ignore")
 class AceStepHandler:
     """ACE-Step Business Logic Handler"""
 
-    def __init__(self, save_root = None):
         self.model = None
         self.config = None
         self.device = "cpu"
         self.dtype = torch.float32  # Will be set based on device in initialize_service
-        if save_root is None:
-            self.temp_dir = tempfile.mkdtemp()
-        else:
-            self.temp_dir = save_root
-
         # VAE for audio encoding/decoding
         self.vae = None
 
@@ -81,8 +79,7 @@
     def get_available_checkpoints(self) -> str:
         """Return project root directory path"""
         # Get project root (handler.py is in acestep/, so go up two levels to project root)
-        current_file = os.path.abspath(__file__)
-        project_root = os.path.dirname(os.path.dirname(current_file))
         # default checkpoints
         checkpoint_dir = os.path.join(project_root, "checkpoints")
         if os.path.exists(checkpoint_dir):
@@ -93,8 +90,7 @@
     def get_available_acestep_v15_models(self) -> List[str]:
         """Scan and return all model directory names starting with 'acestep-v15-'"""
         # Get project root
-        current_file = os.path.abspath(__file__)
-        project_root = os.path.dirname(os.path.dirname(current_file))
         checkpoint_dir = os.path.join(project_root, "checkpoints")
 
         models = []
@@ -171,8 +167,7 @@
 
 
         # Auto-detect project root (independent of passed project_root parameter)
-        current_file = os.path.abspath(__file__)
-        actual_project_root = os.path.dirname(os.path.dirname(current_file))
         checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
 
         # 1. Load main model
@@ -187,7 +182,7 @@
             attn_implementation = "sdpa"
 
         try:
-            logger.info(f"Attempting to load model with attention implementation: {attn_implementation}")
             self.model = AutoModel.from_pretrained(
                 acestep_v15_checkpoint_path,
                 trust_remote_code=True,
@@ -195,9 +190,9 @@
                 dtype="bfloat16"
             )
         except Exception as e:
-            logger.warning(f"Failed to load model with {attn_implementation}: {e}")
             if attn_implementation == "sdpa":
-                logger.info("Falling back to eager attention")
                 attn_implementation = "eager"
                 self.model = AutoModel.from_pretrained(
                     acestep_v15_checkpoint_path,
@@ -215,7 +210,7 @@
         else:
             # If offload_to_cpu is True, check if we should keep DiT on GPU
             if not self.offload_dit_to_cpu:
-                logger.info(f"Keeping main model on {device} (persistent)")
                 self.model = self.model.to(device).to(self.dtype)
             else:
                 self.model = self.model.to("cpu").to(self.dtype)
@@ -239,7 +234,7 @@
                 raise ValueError(f"Unsupported quantization type: {self.quantization}")
 
             quantize_(self.model, quant_config)
-            logger.info(f"DiT quantized with: {self.quantization}")
 
 
         silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
@@ -260,7 +255,7 @@
         if os.path.exists(vae_checkpoint_path):
             self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
             # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
-            vae_dtype = torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
             if not self.offload_to_cpu:
                 self.vae = self.vae.to(device).to(vae_dtype)
             else:
@@ -302,6 +297,7 @@
 
         except Exception as e:
             error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
             return error_msg, False
 
     @contextmanager
@@ -326,7 +322,7 @@
             try:
                 param = next(model.parameters())
                 if param.device.type == "cpu":
-                    logger.info(f"Moving {model_name} to {self.device} (persistent)")
                     model.to(self.device).to(self.dtype)
                     if hasattr(self, "silence_latent"):
                         self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
@@ -341,10 +337,10 @@
             return
 
         # Load to GPU
-        logger.info(f"Loading {model_name} to {self.device}")
         start_time = time.time()
         if model_name == "vae":
-            vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
             model.to(self.device).to(vae_dtype)
         else:
             model.to(self.device).to(self.dtype)
@@ -354,13 +350,13 @@
 
         load_time = time.time() - start_time
         self.current_offload_cost += load_time
-        logger.info(f"Loaded {model_name} to {self.device} in {load_time:.4f}s")
 
         try:
             yield
         finally:
             # Offload to CPU
-            logger.info(f"Offloading {model_name} to CPU")
             start_time = time.time()
             model.to("cpu")
 
@@ -370,7 +366,7 @@
                 torch.cuda.empty_cache()
             offload_time = time.time() - start_time
             self.current_offload_cost += offload_time
-            logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
 
     def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
         """Process target audio"""
@@ -386,23 +382,12 @@
             else:
                 audio = torch.from_numpy(audio_np.T)
 
-            if audio.shape[0] == 1:
-                audio = torch.cat([audio, audio], dim=0)
-
-            audio = audio[:2]
-
-            # Resample if needed
-            if sr != 48000:
-                import torch.nn.functional as F
-                ratio = 48000 / sr
-                new_length = int(audio.shape[-1] * ratio)
-                audio = F.interpolate(audio.unsqueeze(0), size=new_length, mode='linear', align_corners=False).squeeze(0)
-
-            audio = torch.clamp(audio, -1.0, 1.0)
 
             return audio
         except Exception as e:
-            logger.error(f"Error processing target audio: {e}")
             return None
 
     def _parse_audio_code_string(self, code_str: str) -> List[int]:
@@ -411,7 +396,8 @@
             return []
         try:
             return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
-        except Exception:
             return []
 
     def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
@@ -538,9 +524,7 @@
         )
         """
         # Align instruction formatting with _prepare_batch
-        final_instruction = instruction or DEFAULT_DIT_INSTRUCTION
-        if not final_instruction.endswith(":"):
-            final_instruction = final_instruction + ":"
 
         # Extract caption and language from metas if available (from LM CoT output)
         # Fallback to user-provided values if not in metas
@@ -571,7 +555,7 @@
 
         parsed_meta = self._parse_metas([metas])[0]
         caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
-        lyrics_input = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
         return caption_input, lyrics_input
 
     def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -614,7 +598,7 @@
                 return match.group(1).strip()
             return caption
         except Exception as e:
-            logger.error(f"Error extracting caption: {e}")
             return caption
 
     def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
@@ -638,7 +622,8 @@
                 else:
                     try:
                         seed_list.append(int(float(s)))
-                    except (ValueError, TypeError):
                         seed_list.append(-1)
         elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
             # If seed is None or negative, use -1 for all items
@@ -679,7 +664,176 @@
         return actual_seed_list, seed_value_for_ui
 
     def prepare_metadata(self, bpm, key_scale, time_signature):
-        # Build metadata dict - use "N/A" as default for empty fields
         metadata_dict = {}
         if bpm:
             metadata_dict["bpm"] = bpm
@@ -695,10 +849,12 @@
             metadata_dict["timesignature"] = time_signature
         else:
             metadata_dict["timesignature"] = "N/A"
         return metadata_dict
-
-    def is_silence(self, audio):
-        return torch.all(audio.abs() < 1e-6)
 
     def generate_instruction(
         self,
@@ -745,23 +901,12 @@
             # Load audio file
             audio, sr = torchaudio.load(audio_file)
 
-            logger.info(f"Reference audio shape: {audio.shape}")
-            logger.info(f"Reference audio sample rate: {sr}")
-            logger.info(f"Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
-
-            # Convert to stereo (duplicate channel if mono)
-            if audio.shape[0] == 1:
-                audio = torch.cat([audio, audio], dim=0)
 
-            # Keep only first 2 channels
-            audio = audio[:2]
-
-            # Resample to 48kHz if needed
-            if sr != 48000:
-                audio = torchaudio.transforms.Resample(sr, 48000)(audio)
-
-            # Clamp values to [-1.0, 1.0]
-            audio = torch.clamp(audio, -1.0, 1.0)
 
             is_silence = self.is_silence(audio)
             if is_silence:
@@ -800,7 +945,7 @@
             return audio
 
         except Exception as e:
-            logger.error(f"Error processing reference audio: {e}")
             return None
 
     def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
@@ -811,24 +956,13 @@
             # Load audio file
             audio, sr = torchaudio.load(audio_file)
 
-            # Convert to stereo (duplicate channel if mono)
-            if audio.shape[0] == 1:
-                audio = torch.cat([audio, audio], dim=0)
-
-            # Keep only first 2 channels
-            audio = audio[:2]
-
-            # Resample to 48kHz if needed
-            if sr != 48000:
-                audio = torchaudio.transforms.Resample(sr, 48000)(audio)
-
-            # Clamp values to [-1.0, 1.0]
-            audio = torch.clamp(audio, -1.0, 1.0)
 
             return audio
 
         except Exception as e:
-            logger.error(f"Error processing target audio: {e}")
             return None
 
     def convert_src_audio_to_codes(self, audio_file) -> str:
@@ -856,19 +990,12 @@
             # Encode audio to latents using VAE
             with torch.no_grad():
                 with self._load_model_context("vae"):
-                    # Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
-                    vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
-
                     # Check if audio is silence
-                    if self.is_silence(vae_input):
                         return "❌ Audio file appears to be silent"
 
-                    # Encode to latents
-                    latents = self.vae.encode(vae_input).latent_dist.sample()
-                    # Cast back to model dtype
-                    latents = latents.to(self.dtype)
-                    # Transpose: [1, d, T] -> [1, T, d] -> [T, d]
-                    latents = latents.squeeze(0).transpose(0, 1)  # [T, d]
 
                     # Create attention mask for latents
                     attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
@@ -893,7 +1020,7 @@
 
         except Exception as e:
             error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
-            logger.error(error_msg)
             return error_msg
 
     def prepare_batch_data(
@@ -922,26 +1049,7 @@
             calculated_duration = audio_duration
 
         # Build metadata dict - use "N/A" as default for empty fields
-        metadata_dict = {}
-        if bpm:
-            metadata_dict["bpm"] = bpm
-        else:
-            metadata_dict["bpm"] = "N/A"
-
-        if key_scale.strip():
-            metadata_dict["keyscale"] = key_scale
-        else:
-            metadata_dict["keyscale"] = "N/A"
-
-        if time_signature.strip() and time_signature != "N/A" and time_signature:
-            metadata_dict["timesignature"] = time_signature
-        else:
-            metadata_dict["timesignature"] = "N/A"
-
-        # Add duration to metadata if available (inference service format: "30 seconds")
-        if calculated_duration is not None:
-            metadata_dict["duration"] = f"{int(calculated_duration)} seconds"
-        # If duration not set, inference service will use default (30 seconds)
 
         # Format metadata - inference service accepts dict and will convert to string
         # Create a copy for each batch item (in case we modify it)
@@ -977,7 +1085,7 @@
                 target_wavs = torch.zeros(2, frames)
             return target_wavs
         except Exception as e:
-            logger.error(f"Error creating target audio: {e}")
            # Fallback to 30 seconds if error
            return torch.zeros(2, 30 * 48000)
 
@@ -1158,16 +1266,8 @@
         """
         batch_size = len(captions)
 
-        # Ensure audio_code_hints is a list of the correct length
-        if audio_code_hints is None:
-            audio_code_hints = [None] * batch_size
-        elif len(audio_code_hints) != batch_size:
-            if len(audio_code_hints) == 1:
-                audio_code_hints = audio_code_hints * batch_size
-            else:
-                audio_code_hints = audio_code_hints[:batch_size]
-                while len(audio_code_hints) < batch_size:
-                    audio_code_hints.append(None)
 
         for ii, refer_audio_list in enumerate(refer_audios):
             if isinstance(refer_audio_list, list):
@@ -1179,17 +1279,6 @@
         if vocal_languages is None:
             vocal_languages = self._create_fallback_vocal_languages(batch_size)
 
-        # Normalize audio_code_hints to batch list
-        if audio_code_hints is None:
-            audio_code_hints = [None] * batch_size
-        elif not isinstance(audio_code_hints, list):
-            audio_code_hints = [audio_code_hints] * batch_size
-        elif len(audio_code_hints) == 1 and batch_size > 1:
-            audio_code_hints = audio_code_hints * batch_size
-        else:
-            audio_code_hints = (audio_code_hints + [None] * batch_size)[:batch_size]
-        audio_code_hints = [hint if isinstance(hint, str) and hint.strip() else None for hint in audio_code_hints]
-
         # Parse metas with fallbacks
         parsed_metas = self._parse_metas(metas)
@@ -1223,13 +1312,9 @@
                 expected_latent_length = current_wav.shape[-1] // 1920
                 target_latent = self.silence_latent[0, :expected_latent_length, :]
             else:
-                # Ensure input is in VAE's dtype
                 logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
-                vae_input = current_wav.to(self.device).to(self.vae.dtype)
-                target_latent = self.vae.encode(vae_input).latent_dist.sample()
-                # Cast back to model dtype
-                target_latent = target_latent.to(self.dtype)
-                target_latent = target_latent.squeeze(0).transpose(0, 1)
             target_latents_list.append(target_latent)
             latent_lengths.append(target_latent.shape[0])
1235
 
@@ -1268,18 +1353,7 @@ class AceStepHandler:
1268
 
1269
  # Process instructions early so we can use them for task type detection
1270
  # Use custom instructions if provided, otherwise use default
1271
- if instructions is None:
1272
- instructions = [DEFAULT_DIT_INSTRUCTION] * batch_size
1273
-
1274
- # Ensure instructions list has the same length as batch_size
1275
- if len(instructions) != batch_size:
1276
- if len(instructions) == 1:
1277
- instructions = instructions * batch_size
1278
- else:
1279
- # Pad or truncate to match batch_size
1280
- instructions = instructions[:batch_size]
1281
- while len(instructions) < batch_size:
1282
- instructions.append(DEFAULT_DIT_INSTRUCTION)
1283
 
1284
  # Generate chunk_masks and spans based on repainting parameters
1285
  # Also determine if this is a cover task (target audio provided without repainting)
@@ -1428,6 +1502,10 @@ class AceStepHandler:
1428
  else:
1429
  precomputed_lm_hints_25Hz = None
1430
 
 
 
 
 
1431
  # Format text_inputs
1432
  text_inputs = []
1433
  text_token_idss = []
@@ -1437,26 +1515,10 @@ class AceStepHandler:
1437
 
1438
  for i in range(batch_size):
1439
  # Use custom instruction for this batch item
1440
- instruction = instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION
1441
- # Ensure instruction ends with ":"
1442
- if not instruction.endswith(":"):
1443
- instruction = instruction + ":"
1444
-
1445
- # Extract caption and language from metas if available (from LM CoT output)
1446
- # Fallback to user-provided values if not in metas
1447
- actual_caption = captions[i]
1448
- actual_language = vocal_languages[i]
1449
-
1450
- # Check if metas contains caption/language from LM CoT
1451
- if i < len(parsed_metas) and parsed_metas[i]:
1452
- meta_dict = parsed_metas[i]
1453
- if isinstance(meta_dict, dict):
1454
- # Extract caption from metas if available
1455
- if 'caption' in meta_dict and meta_dict['caption']:
1456
- actual_caption = str(meta_dict['caption'])
1457
- # Extract language from metas if available
1458
- if 'language' in meta_dict and meta_dict['language']:
1459
- actual_language = str(meta_dict['language'])
1460
 
1461
  # Format text prompt with custom instruction (using LM-generated caption if available)
1462
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
@@ -1473,7 +1535,7 @@ class AceStepHandler:
1473
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1474
 
1475
  # Format and tokenize lyrics (using LM-generated language if available)
1476
- lyrics_text = f"# Languages\n{actual_language}\n\n# Lyric\n{lyrics[i]}<|endoftext|>"
1477
  lyrics_inputs_dict = self.text_tokenizer(
1478
  lyrics_text,
1479
  padding="longest",
@@ -1495,36 +1557,12 @@ class AceStepHandler:
1495
 
1496
  # Pad tokenized sequences
1497
  max_text_length = max(len(seq) for seq in text_token_idss)
1498
- padded_text_token_idss = torch.stack([
1499
- torch.nn.functional.pad(
1500
- seq, (0, max_text_length - len(seq)), 'constant',
1501
- self.text_tokenizer.pad_token_id
1502
- )
1503
- for seq in text_token_idss
1504
- ])
1505
-
1506
- padded_text_attention_masks = torch.stack([
1507
- torch.nn.functional.pad(
1508
- seq, (0, max_text_length - len(seq)), 'constant', 0
1509
- )
1510
- for seq in text_attention_masks
1511
- ])
1512
 
1513
  max_lyric_length = max(len(seq) for seq in lyric_token_idss)
1514
- padded_lyric_token_idss = torch.stack([
1515
- torch.nn.functional.pad(
1516
- seq, (0, max_lyric_length - len(seq)), 'constant',
1517
- self.text_tokenizer.pad_token_id
1518
- )
1519
- for seq in lyric_token_idss
1520
- ])
1521
-
1522
- padded_lyric_attention_masks = torch.stack([
1523
- torch.nn.functional.pad(
1524
- seq, (0, max_lyric_length - len(seq)), 'constant', 0
1525
- )
1526
- for seq in lyric_attention_masks
1527
- ])
1528
 
1529
  padded_non_cover_text_input_ids = None
1530
  padded_non_cover_text_attention_masks = None
@@ -1533,14 +1571,10 @@ class AceStepHandler:
1533
  non_cover_text_attention_masks = []
1534
  for i in range(batch_size):
1535
  # Use custom instruction for this batch item
1536
- instruction = DEFAULT_DIT_INSTRUCTION
1537
 
1538
  # Extract caption from metas if available (from LM CoT output)
1539
- actual_caption = captions[i]
1540
- if i < len(parsed_metas) and parsed_metas[i]:
1541
- meta_dict = parsed_metas[i]
1542
- if isinstance(meta_dict, dict) and 'caption' in meta_dict and meta_dict['caption']:
1543
- actual_caption = str(meta_dict['caption'])
1544
 
1545
  # Format text prompt with custom instruction (using LM-generated caption if available)
1546
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
@@ -1558,19 +1592,8 @@ class AceStepHandler:
1558
  non_cover_text_input_ids.append(text_token_ids)
1559
  non_cover_text_attention_masks.append(non_cover_text_attention_mask)
1560
 
1561
- padded_non_cover_text_input_ids = torch.stack([
1562
- torch.nn.functional.pad(
1563
- seq, (0, max_text_length - len(seq)), 'constant',
1564
- self.text_tokenizer.pad_token_id
1565
- )
1566
- for seq in non_cover_text_input_ids
1567
- ])
1568
- padded_non_cover_text_attention_masks = torch.stack([
1569
- torch.nn.functional.pad(
1570
- seq, (0, max_text_length - len(seq)), 'constant', 0
1571
- )
1572
- for seq in non_cover_text_attention_masks
1573
- ])
1574
 
1575
  if audio_cover_strength < 1.0:
1576
  assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
@@ -1804,7 +1827,7 @@ class AceStepHandler:
1804
  if self.config.is_turbo:
1805
  # Limit inference steps to maximum 8
1806
  if infer_steps > 8:
1807
- logger.warning(f"dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
1808
  infer_steps = 8
1809
  # CFG parameters are not adjustable for dmd_gan (they will be ignored)
1810
  # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
@@ -1827,30 +1850,12 @@ class AceStepHandler:
1827
  if isinstance(repainting_end, (int, float)):
1828
  repainting_end = [repainting_end]
1829
 
1830
- # Convert instructions to list
1831
- if isinstance(instructions, str):
1832
- instructions = [instructions]
1833
- elif instructions is None:
1834
- instructions = None
1835
-
1836
- # Convert audio_code_hints to list
1837
- if isinstance(audio_code_hints, str):
1838
- audio_code_hints = [audio_code_hints]
1839
- elif audio_code_hints is None:
1840
- audio_code_hints = None
1841
-
1842
  # Get batch size from captions
1843
  batch_size = len(captions)
1844
 
1845
- # Ensure audio_code_hints matches batch size
1846
- if audio_code_hints is not None:
1847
- if len(audio_code_hints) != batch_size:
1848
- if len(audio_code_hints) == 1:
1849
- audio_code_hints = audio_code_hints * batch_size
1850
- else:
1851
- audio_code_hints = audio_code_hints[:batch_size]
1852
- while len(audio_code_hints) < batch_size:
1853
- audio_code_hints.append(None)
1854
 
1855
  # Convert seed to list format
1856
  if seed is None:
@@ -1947,6 +1952,14 @@ class AceStepHandler:
1947
  logger.info("[service_generate] Generating audio...")
1948
  with self._load_model_context("model"):
1949
  outputs = self.model.generate_audio(**generate_kwargs)
 
 
 
 
 
 
 
 
1950
  return outputs
1951
 
1952
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
@@ -2042,25 +2055,34 @@ class AceStepHandler:
2042
  use_adg: bool = False,
2043
  cfg_interval_start: float = 0.0,
2044
  cfg_interval_end: float = 1.0,
2045
- audio_format: str = "mp3",
2046
- lm_temperature: float = 0.6,
2047
  use_tiled_decode: bool = True,
2048
  progress=None
2049
- ) -> Tuple[Optional[str], Optional[str], List[str], str, str, str, str, str, Optional[Any], str, str, Optional[Any]]:
2050
  """
2051
  Main interface for music generation
2052
 
2053
  Returns:
2054
- (first_audio, second_audio, all_audio_paths, generation_info, status_message,
2055
- seed_value_for_ui, align_score_1, align_text_1, align_plot_1,
2056
- align_score_2, align_text_2, align_plot_2)
 
 
 
 
2057
  """
2058
  if progress is None:
2059
  def progress(*args, **kwargs):
2060
  pass
2061
 
2062
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
2063
- return None, None, [], "", "❌ Model not fully initialized. Please initialize all components first.", "-1", "", "", None, "", "", None
 
 
 
 
 
 
 
2064
 
2065
  def _has_audio_codes(v: Union[str, List[str]]) -> bool:
2066
  if isinstance(v, list):
@@ -2191,8 +2213,8 @@ class AceStepHandler:
2191
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
2192
  time_costs = outputs["time_costs"]
2193
  time_costs["offload_time_cost"] = self.current_offload_cost
2194
- logger.info(f" - pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
2195
- logger.info(f" - time_costs: {time_costs}")
2196
  if progress:
2197
  progress(0.8, desc="Decoding audio...")
2198
  logger.info("[generate_music] Decoding latents with VAE...")
@@ -2221,30 +2243,19 @@ class AceStepHandler:
2221
  # Update offload cost one last time to include VAE offloading
2222
  time_costs["offload_time_cost"] = self.current_offload_cost
2223
 
2224
- logger.info("[generate_music] VAE decode completed. Saving audio files...")
2225
  if progress:
2226
- progress(0.9, desc="Saving audio files...")
2227
 
2228
- # Save audio files using soundfile (supports wav, flac, mp3 via format param)
2229
- audio_format_lower = audio_format.lower() if audio_format else "wav"
2230
- if audio_format_lower not in ["wav", "flac", "mp3"]:
2231
- audio_format_lower = "wav"
2232
 
2233
- saved_files = []
2234
- saved_uuids = [] # Store UUIDs for each file
2235
  for i in range(actual_batch_size):
2236
- # Generate unique UUID for each audio file
2237
- file_uuid = str(uuid.uuid4())
2238
- audio_file = os.path.join(self.temp_dir, f"{file_uuid}.{audio_format_lower}")
2239
- # Convert to numpy: [channels, samples] -> [samples, channels]
2240
- audio_np = pred_wavs[i].cpu().float().numpy().T
2241
- sf.write(audio_file, audio_np, self.sample_rate)
2242
- saved_files.append(audio_file)
2243
- saved_uuids.append(file_uuid)
2244
-
2245
- # Prepare return values
2246
- first_audio = saved_files[0] if len(saved_files) > 0 else None
2247
- second_audio = saved_files[1] if len(saved_files) > 1 else None
2248
 
2249
  # Format time costs if available
2250
  time_costs_str = ""
@@ -2262,34 +2273,55 @@ class AceStepHandler:
2262
 
2263
  **Seeds:** {seed_value_for_ui}
2264
  **Steps:** {inference_steps}
2265
- **Files:** {len(saved_files)} audio(s){time_costs_str}"""
2266
  status_message = f"✅ Generation completed successfully!"
2267
- logger.info(f"[generate_music] Done! Generated {len(saved_files)} audio files.")
2268
-
2269
- # Alignment scores and plots (placeholder for now)
2270
- align_score_1 = ""
2271
- align_text_1 = ""
2272
- align_plot_1 = None
2273
- align_score_2 = ""
2274
- align_text_2 = ""
2275
- align_plot_2 = None
2276
-
2277
- return (
2278
- first_audio,
2279
- second_audio,
2280
- saved_files,
2281
- generation_info,
2282
- status_message,
2283
- seed_value_for_ui,
2284
- align_score_1,
2285
- align_text_1,
2286
- align_plot_1,
2287
- align_score_2,
2288
- align_text_2,
2289
- align_plot_2,
2290
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2291
 
2292
  except Exception as e:
2293
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2294
- return None, None, [], "", error_msg, seed_value_for_ui, "", "", None, "", "", None
2295
-
 
 
 
 
 
 
 
 
10
  import re
11
  import random
12
  import uuid
13
+ import hashlib
14
+ import json
15
  from contextlib import contextmanager
16
  from typing import Optional, Dict, Any, Tuple, List, Union
17
 
 
39
  class AceStepHandler:
40
  """ACE-Step Business Logic Handler"""
41
 
42
+ def __init__(self):
43
  self.model = None
44
  self.config = None
45
  self.device = "cpu"
46
  self.dtype = torch.float32 # Will be set based on device in initialize_service
47
+
 
 
 
 
48
  # VAE for audio encoding/decoding
49
  self.vae = None
50
 
 
79
  def get_available_checkpoints(self) -> str:
80
  """Return project root directory path"""
81
  # Get project root (handler.py is in acestep/, so go up two levels to project root)
82
+ project_root = self._get_project_root()
 
83
  # default checkpoints
84
  checkpoint_dir = os.path.join(project_root, "checkpoints")
85
  if os.path.exists(checkpoint_dir):
 
90
  def get_available_acestep_v15_models(self) -> List[str]:
91
  """Scan and return all model directory names starting with 'acestep-v15-'"""
92
  # Get project root
93
+ project_root = self._get_project_root()
 
94
  checkpoint_dir = os.path.join(project_root, "checkpoints")
95
 
96
  models = []
 
167
 
168
 
169
  # Auto-detect project root (independent of passed project_root parameter)
170
+ actual_project_root = self._get_project_root()
 
171
  checkpoint_dir = os.path.join(actual_project_root, "checkpoints")
172
 
173
  # 1. Load main model
 
182
  attn_implementation = "sdpa"
183
 
184
  try:
185
+ logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}")
186
  self.model = AutoModel.from_pretrained(
187
  acestep_v15_checkpoint_path,
188
  trust_remote_code=True,
 
190
  dtype="bfloat16"
191
  )
192
  except Exception as e:
193
+ logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}")
194
  if attn_implementation == "sdpa":
195
+ logger.info("[initialize_service] Falling back to eager attention")
196
  attn_implementation = "eager"
197
  self.model = AutoModel.from_pretrained(
198
  acestep_v15_checkpoint_path,
 
210
  else:
211
  # If offload_to_cpu is True, check if we should keep DiT on GPU
212
  if not self.offload_dit_to_cpu:
213
+ logger.info(f"[initialize_service] Keeping main model on {device} (persistent)")
214
  self.model = self.model.to(device).to(self.dtype)
215
  else:
216
  self.model = self.model.to("cpu").to(self.dtype)
 
234
  raise ValueError(f"Unsupported quantization type: {self.quantization}")
235
 
236
  quantize_(self.model, quant_config)
237
+ logger.info(f"[initialize_service] DiT quantized with: {self.quantization}")
238
 
239
 
240
  silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt")
 
255
  if os.path.exists(vae_checkpoint_path):
256
  self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path)
257
  # Use bfloat16 for VAE on GPU, otherwise use self.dtype (float32 on CPU)
258
+ vae_dtype = self._get_vae_dtype(device)
259
  if not self.offload_to_cpu:
260
  self.vae = self.vae.to(device).to(vae_dtype)
261
  else:
 
297
 
298
  except Exception as e:
299
  error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
300
+ logger.exception("[initialize_service] Error initializing model")
301
  return error_msg, False
302
 
303
  @contextmanager
 
322
  try:
323
  param = next(model.parameters())
324
  if param.device.type == "cpu":
325
+ logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)")
326
  model.to(self.device).to(self.dtype)
327
  if hasattr(self, "silence_latent"):
328
  self.silence_latent = self.silence_latent.to(self.device).to(self.dtype)
 
337
  return
338
 
339
  # Load to GPU
340
+ logger.info(f"[_load_model_context] Loading {model_name} to {self.device}")
341
  start_time = time.time()
342
  if model_name == "vae":
343
+ vae_dtype = self._get_vae_dtype()
344
  model.to(self.device).to(vae_dtype)
345
  else:
346
  model.to(self.device).to(self.dtype)
 
350
 
351
  load_time = time.time() - start_time
352
  self.current_offload_cost += load_time
353
+ logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s")
354
 
355
  try:
356
  yield
357
  finally:
358
  # Offload to CPU
359
+ logger.info(f"[_load_model_context] Offloading {model_name} to CPU")
360
  start_time = time.time()
361
  model.to("cpu")
362
 
 
366
  torch.cuda.empty_cache()
367
  offload_time = time.time() - start_time
368
  self.current_offload_cost += offload_time
369
+ logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s")
370
 
371
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
372
  """Process target audio"""
 
382
  else:
383
  audio = torch.from_numpy(audio_np.T)
384
 
385
+ # Normalize to stereo 48kHz
386
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  return audio
389
  except Exception as e:
390
+ logger.exception("[process_target_audio] Error processing target audio")
391
  return None
392
 
393
  def _parse_audio_code_string(self, code_str: str) -> List[int]:
 
396
  return []
397
  try:
398
  return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)]
399
+ except Exception as e:
400
+ logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}")
401
  return []
402
 
403
  def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]:
 
524
  )
525
  """
526
  # Align instruction formatting with _prepare_batch
527
+ final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION)
 
 
528
 
529
  # Extract caption and language from metas if available (from LM CoT output)
530
  # Fallback to user-provided values if not in metas
 
555
 
556
  parsed_meta = self._parse_metas([metas])[0]
557
  caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta)
558
+ lyrics_input = self._format_lyrics(lyrics, actual_language)
559
  return caption_input, lyrics_input
560
 
561
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
 
598
  return match.group(1).strip()
599
  return caption
600
  except Exception as e:
601
+ logger.exception("[extract_caption_from_sft_format] Error extracting caption")
602
  return caption
603
 
604
  def prepare_seeds(self, actual_batch_size, seed, use_random_seed):
 
622
  else:
623
  try:
624
  seed_list.append(int(float(s)))
625
+ except (ValueError, TypeError) as e:
626
+ logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}")
627
  seed_list.append(-1)
628
  elif seed is None or (isinstance(seed, (int, float)) and seed < 0):
629
  # If seed is None or negative, use -1 for all items
 
664
  return actual_seed_list, seed_value_for_ui
665
 
666
  def prepare_metadata(self, bpm, key_scale, time_signature):
667
+ """Build metadata dict - use "N/A" as default for empty fields."""
668
+ return self._build_metadata_dict(bpm, key_scale, time_signature)
669
+
670
+ def is_silence(self, audio):
671
+ return torch.all(audio.abs() < 1e-6)
672
+
673
+ def _get_project_root(self) -> str:
674
+ """Get project root directory path."""
675
+ current_file = os.path.abspath(__file__)
676
+ return os.path.dirname(os.path.dirname(current_file))
677
+
678
+ def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype:
679
+ """Get VAE dtype based on device."""
680
+ device = device or self.device
681
+ return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype
682
+
683
+ def _format_instruction(self, instruction: str) -> str:
684
+ """Format instruction to ensure it ends with colon."""
685
+ if not instruction.endswith(":"):
686
+ instruction = instruction + ":"
687
+ return instruction
688
+
689
+ def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor:
690
+ """
691
+ Normalize audio to stereo 48kHz format.
692
+
693
+ Args:
694
+ audio: Audio tensor [channels, samples] or [samples]
695
+ sr: Sample rate
696
+
697
+ Returns:
698
+ Normalized audio tensor [2, samples] at 48kHz
699
+ """
700
+ # Convert to stereo (duplicate channel if mono)
701
+ if audio.shape[0] == 1:
702
+ audio = torch.cat([audio, audio], dim=0)
703
+
704
+ # Keep only first 2 channels
705
+ audio = audio[:2]
706
+
707
+ # Resample to 48kHz if needed
708
+ if sr != 48000:
709
+ audio = torchaudio.transforms.Resample(sr, 48000)(audio)
710
+
711
+ # Clamp values to [-1.0, 1.0]
712
+ audio = torch.clamp(audio, -1.0, 1.0)
713
+
714
+ return audio
715
+
716
+ def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]:
717
+ """Normalize audio_code_hints to list of correct length."""
718
+ if audio_code_hints is None:
719
+ normalized = [None] * batch_size
720
+ elif isinstance(audio_code_hints, str):
721
+ normalized = [audio_code_hints] * batch_size
722
+ elif len(audio_code_hints) == 1 and batch_size > 1:
723
+ normalized = audio_code_hints * batch_size
724
+ elif len(audio_code_hints) != batch_size:
725
+ # Pad or truncate to match batch_size
726
+ normalized = list(audio_code_hints[:batch_size])
727
+ while len(normalized) < batch_size:
728
+ normalized.append(None)
729
+ else:
730
+ normalized = list(audio_code_hints)
731
+
732
+ # Clean up: convert empty strings to None
733
+ normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized]
734
+ return normalized
735
+
736
+ def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]:
737
+ """Normalize instructions to list of correct length."""
738
+ if instructions is None:
739
+ default_instruction = default or DEFAULT_DIT_INSTRUCTION
740
+ return [default_instruction] * batch_size
741
+ elif isinstance(instructions, str):
742
+ return [instructions] * batch_size
743
+ elif len(instructions) == 1:
744
+ return instructions * batch_size
745
+ elif len(instructions) != batch_size:
746
+ # Pad or truncate to match batch_size
747
+ normalized = list(instructions[:batch_size])
748
+ default_instruction = default or DEFAULT_DIT_INSTRUCTION
749
+ while len(normalized) < batch_size:
750
+ normalized.append(default_instruction)
751
+ return normalized
752
+ else:
753
+ return list(instructions)
754
+
755
+ def _format_lyrics(self, lyrics: str, language: str) -> str:
756
+ """Format lyrics text with language header."""
757
+ return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>"
758
+
759
+ def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor:
760
+ """Pad sequences to same length."""
761
+ return torch.stack([
762
+ torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value)
763
+ for seq in sequences
764
+ ])
765
+
766
+ def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]:
767
+ """Extract caption and language from metas with fallback to provided values."""
768
+ actual_captions = list(captions)
769
+ actual_languages = list(vocal_languages)
770
+
771
+ for i, meta in enumerate(metas):
772
+ if i >= len(actual_captions):
773
+ break
774
+
775
+ meta_dict = None
776
+ if isinstance(meta, str):
777
+ parsed = self._parse_metas([meta])
778
+ if parsed and isinstance(parsed[0], dict):
779
+ meta_dict = parsed[0]
780
+ elif isinstance(meta, dict):
781
+ meta_dict = meta
782
+
783
+ if meta_dict:
784
+ if 'caption' in meta_dict and meta_dict['caption']:
785
+ actual_captions[i] = str(meta_dict['caption'])
786
+ if 'language' in meta_dict and meta_dict['language']:
787
+ actual_languages[i] = str(meta_dict['language'])
788
+
789
+ return actual_captions, actual_languages
790
+
791
+ def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor:
792
+ """
793
+ Encode audio to latents using VAE.
794
+
795
+ Args:
796
+ audio: Audio tensor [channels, samples] or [batch, channels, samples]
797
+
798
+ Returns:
799
+ Latents tensor [T, D] or [batch, T, D]
800
+ """
801
+ # Ensure batch dimension
802
+ squeeze_output = audio.dim() == 2  # remember whether the input lacked a batch dim
803
+ audio = audio.unsqueeze(0) if squeeze_output else audio
804
+
805
+ # Ensure input is in VAE's dtype
806
+ vae_input = audio.to(self.device).to(self.vae.dtype)
807
+
808
+ # Encode to latents
809
+ with torch.no_grad():
810
+ latents = self.vae.encode(vae_input).latent_dist.sample()
811
+
812
+ # Cast back to model dtype
813
+ latents = latents.to(self.dtype)
814
+
815
+ # Transpose: [batch, d, T] -> [batch, T, d]
816
+ latents = latents.transpose(1, 2)
817
+
818
+ # Remove batch dimension if input didn't have it
819
+ if squeeze_output:
820
+ latents = latents.squeeze(0)
821
+
822
+ return latents
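# Illustrative shape walk-through (assumes a loaded VAE; not part of handler.py):
#   audio   = torch.zeros(2, 48000 * 10)                  # 10 s stereo at 48 kHz
#   latents = handler._encode_audio_to_latents(audio)     # -> [T, d], batch dim squeezed away
#   batch   = torch.zeros(4, 2, 48000 * 10)
#   latents = handler._encode_audio_to_latents(batch)     # -> [4, T, d], batch dim preserved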
823
+
824
+ def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]:
825
+ """
826
+ Build metadata dictionary with default values.
827
+
828
+ Args:
829
+ bpm: BPM value (optional)
830
+ key_scale: Key/scale string
831
+ time_signature: Time signature string
832
+ duration: Duration in seconds (optional)
833
+
834
+ Returns:
835
+ Metadata dictionary
836
+ """
837
  metadata_dict = {}
838
  if bpm:
839
  metadata_dict["bpm"] = bpm
 
849
  metadata_dict["timesignature"] = time_signature
850
  else:
851
  metadata_dict["timesignature"] = "N/A"
852
+
853
+ # Add duration if provided
854
+ if duration is not None:
855
+ metadata_dict["duration"] = f"{int(duration)} seconds"
856
+
857
  return metadata_dict
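# Illustrative sketch (not part of handler.py): the padding behaviour of the
# _pad_sequences helper above, reproduced with standalone tensors.
import torch
import torch.nn.functional as F

pad_id = 0  # assumed pad token id for the example
seqs = [torch.tensor([5, 9, 12]), torch.tensor([7, 3])]
max_len = max(len(s) for s in seqs)
padded = torch.stack([F.pad(s, (0, max_len - len(s)), "constant", pad_id) for s in seqs])
# padded -> tensor([[ 5,  9, 12], [ 7,  3,  0]]); attention masks are padded the same way with pad value 0.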
 
 
 
858
 
859
  def generate_instruction(
860
  self,
 
901
  # Load audio file
902
  audio, sr = torchaudio.load(audio_file)
903
 
904
+ logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
905
+ logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
906
+ logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds")
 
 
 
 
907
 
908
+ # Normalize to stereo 48kHz
909
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
 
 
 
 
 
 
910
 
911
  is_silence = self.is_silence(audio)
912
  if is_silence:
 
945
  return audio
946
 
947
  except Exception as e:
948
+ logger.exception("[process_reference_audio] Error processing reference audio")
949
  return None
950
 
951
  def process_src_audio(self, audio_file) -> Optional[torch.Tensor]:
 
956
  # Load audio file
957
  audio, sr = torchaudio.load(audio_file)
958
 
959
+ # Normalize to stereo 48kHz
960
+ audio = self._normalize_audio_to_stereo_48k(audio, sr)
 
 
 
 
 
 
 
 
 
 
 
961
 
962
  return audio
963
 
964
  except Exception as e:
965
+ logger.exception("[process_src_audio] Error processing source audio")
966
  return None
967
 
968
  def convert_src_audio_to_codes(self, audio_file) -> str:
 
990
  # Encode audio to latents using VAE
991
  with torch.no_grad():
992
  with self._load_model_context("vae"):
 
 
 
993
  # Check if audio is silence
994
+ if self.is_silence(processed_audio.unsqueeze(0)):
995
  return "❌ Audio file appears to be silent"
996
 
997
+ # Encode to latents using helper method
998
+ latents = self._encode_audio_to_latents(processed_audio) # [T, d]
 
 
 
 
999
 
1000
  # Create attention mask for latents
1001
  attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
 
1020
 
1021
  except Exception as e:
1022
  error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
1023
+ logger.exception("[convert_src_audio_to_codes] Error converting audio to codes")
1024
  return error_msg
1025
 
1026
  def prepare_batch_data(
 
1049
  calculated_duration = audio_duration
1050
 
1051
  # Build metadata dict - use "N/A" as default for empty fields
1052
+ metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1053
 
1054
  # Format metadata - inference service accepts dict and will convert to string
1055
  # Create a copy for each batch item (in case we modify it)
 
1085
  target_wavs = torch.zeros(2, frames)
1086
  return target_wavs
1087
  except Exception as e:
1088
+ logger.exception("[create_target_wavs] Error creating target audio")
1089
  # Fallback to 30 seconds if error
1090
  return torch.zeros(2, 30 * 48000)
1091
 
 
1266
  """
1267
  batch_size = len(captions)
1268
 
1269
+ # Normalize audio_code_hints to batch list
1270
+ audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size)
 
 
 
 
 
 
 
 
1271
 
1272
  for ii, refer_audio_list in enumerate(refer_audios):
1273
  if isinstance(refer_audio_list, list):
 
1279
  if vocal_languages is None:
1280
  vocal_languages = self._create_fallback_vocal_languages(batch_size)
1281
 
 
 
 
 
 
 
 
 
 
 
 
1282
  # Parse metas with fallbacks
1283
  parsed_metas = self._parse_metas(metas)
1284
 
 
1312
  expected_latent_length = current_wav.shape[-1] // 1920
1313
  target_latent = self.silence_latent[0, :expected_latent_length, :]
1314
  else:
1315
+ # Encode using helper method
1316
  logger.info(f"[generate_music] Encoding target audio to latents for item {i}...")
1317
+ target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) # Remove batch dim for helper
 
 
 
 
1318
  target_latents_list.append(target_latent)
1319
  latent_lengths.append(target_latent.shape[0])
1320
 
 
1353
 
1354
  # Process instructions early so we can use them for task type detection
1355
  # Use custom instructions if provided, otherwise use default
1356
+ instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION)
 
 
 
 
 
 
 
 
 
 
 
1357
 
1358
  # Generate chunk_masks and spans based on repainting parameters
1359
  # Also determine if this is a cover task (target audio provided without repainting)
 
1502
  else:
1503
  precomputed_lm_hints_25Hz = None
1504
 
1505
+ # Extract caption and language from metas if available (from LM CoT output)
1506
+ # Fallback to user-provided values if not in metas
1507
+ actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages)
1508
+
1509
  # Format text_inputs
1510
  text_inputs = []
1511
  text_token_idss = []
 
1515
 
1516
  for i in range(batch_size):
1517
  # Use custom instruction for this batch item
1518
+ instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION)
1519
+
1520
+ actual_caption = actual_captions[i]
1521
+ actual_language = actual_languages[i]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1522
 
1523
  # Format text prompt with custom instruction (using LM-generated caption if available)
1524
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
 
1535
  text_attention_mask = text_inputs_dict.attention_mask[0].bool()
1536
 
1537
  # Format and tokenize lyrics (using LM-generated language if available)
1538
+ lyrics_text = self._format_lyrics(lyrics[i], actual_language)
1539
  lyrics_inputs_dict = self.text_tokenizer(
1540
  lyrics_text,
1541
  padding="longest",
 
1557
 
1558
  # Pad tokenized sequences
1559
  max_text_length = max(len(seq) for seq in text_token_idss)
1560
+ padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id)
1561
+ padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0)
 
 
 
 
 
 
 
 
 
 
 
 
1562
 
1563
  max_lyric_length = max(len(seq) for seq in lyric_token_idss)
1564
+ padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id)
1565
+ padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0)
 
 
 
 
 
 
 
 
 
 
 
 
1566
 
1567
  padded_non_cover_text_input_ids = None
1568
  padded_non_cover_text_attention_masks = None
 
1571
  non_cover_text_attention_masks = []
1572
  for i in range(batch_size):
1573
  # Use custom instruction for this batch item
1574
+ instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION)
1575
 
1576
  # Extract caption from metas if available (from LM CoT output)
1577
+ actual_caption = actual_captions[i]
 
 
 
 
1578
 
1579
  # Format text prompt with custom instruction (using LM-generated caption if available)
1580
  text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i])
 
1592
  non_cover_text_input_ids.append(text_token_ids)
1593
  non_cover_text_attention_masks.append(non_cover_text_attention_mask)
1594
 
1595
+ padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id)
1596
+ padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0)
 
 
 
 
 
 
 
 
 
 
 
1597
 
1598
  if audio_cover_strength < 1.0:
1599
  assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None"
 
1827
  if self.config.is_turbo:
1828
  # Limit inference steps to maximum 8
1829
  if infer_steps > 8:
1830
+ logger.warning(f"[service_generate] dmd_gan version: infer_steps {infer_steps} exceeds maximum 8, clamping to 8")
1831
  infer_steps = 8
1832
  # CFG parameters are not adjustable for dmd_gan (they will be ignored)
1833
  # Note: guidance_scale, cfg_interval_start, cfg_interval_end are still passed but may be ignored by the model
 
1850
  if isinstance(repainting_end, (int, float)):
1851
  repainting_end = [repainting_end]
1852
 
 
 
 
 
 
 
 
 
 
 
 
 
1853
  # Get batch size from captions
1854
  batch_size = len(captions)
1855
 
1856
+ # Normalize instructions and audio_code_hints to match batch size
1857
+ instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None
1858
+ audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None
 
 
 
 
 
 
1859
 
1860
  # Convert seed to list format
1861
  if seed is None:
 
1952
  logger.info("[service_generate] Generating audio...")
1953
  with self._load_model_context("model"):
1954
  outputs = self.model.generate_audio(**generate_kwargs)
1955
+
1956
+ # Add intermediate information to outputs for extra_outputs
1957
+ outputs["src_latents"] = src_latents
1958
+ outputs["target_latents_input"] = target_latents # Input target latents (before generation)
1959
+ outputs["chunk_masks"] = chunk_mask
1960
+ outputs["spans"] = spans
1961
+ outputs["latent_masks"] = batch.get("latent_masks") # Latent masks for valid length
1962
+
1963
  return outputs
1964
 
1965
  def tiled_decode(self, latents, chunk_size=512, overlap=64):
 
2055
  use_adg: bool = False,
2056
  cfg_interval_start: float = 0.0,
2057
  cfg_interval_end: float = 1.0,
 
 
2058
  use_tiled_decode: bool = True,
2059
  progress=None
2060
+ ) -> Dict[str, Any]:
2061
  """
2062
  Main interface for music generation
2063
 
2064
  Returns:
2065
+ Dictionary containing:
2066
+ - audios: List of audio dictionaries, each holding the audio tensor and its sample rate
2067
+ - generation_info: Markdown-formatted generation information
2068
+ - status_message: Status message
2069
+ - extra_outputs: Dictionary with latents, masks, time_costs, etc.
2070
+ - success: Whether generation completed successfully
2071
+ - error: Error message if generation failed
2072
  """
2073
  if progress is None:
2074
  def progress(*args, **kwargs):
2075
  pass
2076
 
2077
  if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None:
2078
+ return {
2079
+ "audios": [],
2080
+ "generation_info": "",
2081
+ "status_message": "❌ Model not fully initialized. Please initialize all components first.",
2082
+ "extra_outputs": {},
2083
+ "success": False,
2084
+ "error": "Model not fully initialized",
2085
+ }
2086
 
2087
  def _has_audio_codes(v: Union[str, List[str]]) -> bool:
2088
  if isinstance(v, list):
 
2213
  pred_latents = outputs["target_latents"] # [batch, latent_length, latent_dim]
2214
  time_costs = outputs["time_costs"]
2215
  time_costs["offload_time_cost"] = self.current_offload_cost
2216
+ logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}")
2217
+ logger.debug(f"[generate_music] time_costs: {time_costs}")
2218
  if progress:
2219
  progress(0.8, desc="Decoding audio...")
2220
  logger.info("[generate_music] Decoding latents with VAE...")
 
2243
  # Update offload cost one last time to include VAE offloading
2244
  time_costs["offload_time_cost"] = self.current_offload_cost
2245
 
2246
+ logger.info("[generate_music] VAE decode completed. Preparing audio tensors...")
2247
  if progress:
2248
+ progress(0.9, desc="Preparing audio data...")
2249
 
2250
+ # Prepare audio tensors (no file I/O here, no UUID generation)
2251
+ # pred_wavs is already [batch, channels, samples] format
2252
+ # Move to CPU and convert to float32 for return
2253
+ audio_tensors = []
2254
 
 
 
2255
  for i in range(actual_batch_size):
2256
+ # Extract audio tensor: [channels, samples] format, CPU, float32
2257
+ audio_tensor = pred_wavs[i].cpu().float()
2258
+ audio_tensors.append(audio_tensor)
 
 
 
 
 
 
 
 
 
2259
 
2260
  # Format time costs if available
2261
  time_costs_str = ""
 
2273
 
2274
  **Seeds:** {seed_value_for_ui}
2275
  **Steps:** {inference_steps}
2276
+ **Audio Count:** {len(audio_tensors)} audio(s){time_costs_str}"""
2277
  status_message = f"✅ Generation completed successfully!"
2278
+ logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.")
2279
+
2280
+ # Extract intermediate information from outputs
2281
+ src_latents = outputs.get("src_latents") # [batch, T, D]
2282
+ target_latents_input = outputs.get("target_latents_input") # [batch, T, D]
2283
+ chunk_masks = outputs.get("chunk_masks") # [batch, T]
2284
+ spans = outputs.get("spans", []) # List of tuples
2285
+ latent_masks = outputs.get("latent_masks") # [batch, T]
2286
+
2287
+ # Move latents to CPU to save memory (they can be large)
2288
+ extra_outputs = {
2289
+ "pred_latents": pred_latents.cpu() if pred_latents is not None else None,
2290
+ "target_latents": target_latents_input.cpu() if target_latents_input is not None else None,
2291
+ "src_latents": src_latents.cpu() if src_latents is not None else None,
2292
+ "chunk_masks": chunk_masks.cpu() if chunk_masks is not None else None,
2293
+ "latent_masks": latent_masks.cpu() if latent_masks is not None else None,
2294
+ "spans": spans,
2295
+ "time_costs": time_costs,
2296
+ "seed_value": seed_value_for_ui,
2297
+ }
2298
+
2299
+ # Build audios list with tensor data (no file paths, no UUIDs, handled outside)
2300
+ audios = []
2301
+ for idx, audio_tensor in enumerate(audio_tensors):
2302
+ audio_dict = {
2303
+ "tensor": audio_tensor, # torch.Tensor [channels, samples], CPU, float32
2304
+ "sample_rate": self.sample_rate,
2305
+ }
2306
+ audios.append(audio_dict)
2307
+
2308
+ return {
2309
+ "audios": audios,
2310
+ "generation_info": generation_info,
2311
+ "status_message": status_message,
2312
+ "extra_outputs": extra_outputs,
2313
+ "success": True,
2314
+ "error": None,
2315
+ }
2316
 
2317
  except Exception as e:
2318
  error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
2319
+ logger.exception("[generate_music] Generation failed")
2320
+ return {
2321
+ "audios": [],
2322
+ "generation_info": "",
2323
+ "status_message": error_msg,
2324
+ "extra_outputs": {},
2325
+ "success": False,
2326
+ "error": str(e),
2327
+ }
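With this refactor, AceStepHandler.generate_music no longer writes files itself; it hands back CPU float32 tensors and leaves persistence to the caller. A minimal sketch of consuming the new dict (the save_generation helper, its file names, and the torchaudio-based writing are illustrative, not part of this commit):

import torchaudio

def save_generation(result, prefix="output"):
    """Persist the audio tensors returned by AceStepHandler.generate_music()."""
    if not result["success"]:
        raise RuntimeError(result["status_message"])
    paths = []
    for idx, audio in enumerate(result["audios"]):
        # Each entry carries a CPU float32 tensor [channels, samples] and its sample rate.
        path = f"{prefix}_{idx}.flac"
        torchaudio.save(path, audio["tensor"], audio["sample_rate"])
        paths.append(path)
    return paths

The seed actually used and the per-stage timings remain available under result["extra_outputs"]["seed_value"] and result["extra_outputs"]["time_costs"], alongside the CPU copies of pred_latents, src_latents, chunk_masks, latent_masks and spans.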
acestep/inference.py CHANGED
@@ -7,105 +7,100 @@ backward-compatible Gradio UI support.
7
  """
8
 
9
  import math
 
 
10
  from typing import Optional, Union, List, Dict, Any, Tuple
11
  from dataclasses import dataclass, field, asdict
12
  from loguru import logger
13
- import time as time_module
 
14
 
15
 
16
  @dataclass
17
- class GenerationConfig:
18
- """Configuration for music generation.
19
 
20
  Attributes:
21
  # Text Inputs
22
- caption: Text description of the desired music
23
- lyrics: Lyrics text for vocal music (use "[Instrumental]" for instrumental)
 
24
 
25
  # Music Metadata
26
- bpm: Beats per minute (e.g., 120). None for auto-detection
27
- key_scale: Musical key (e.g., "C Major", "Am"). Empty for auto-detection
28
- time_signature: Time signature (e.g., "4/4", "3/4"). Empty for auto-detection
29
- vocal_language: Language code for vocals (e.g., "en", "zh", "ja")
30
- audio_duration: Duration in seconds. None for auto-detection
31
 
32
  # Generation Parameters
33
- inference_steps: Number of denoising steps (8 for turbo, 32-100 for base)
34
- guidance_scale: Classifier-free guidance scale (higher = more adherence to prompt)
35
- use_random_seed: Whether to use random seed (True) or fixed seed
36
- seed: Random seed for reproducibility (-1 for random)
37
- batch_size: Number of samples to generate (1-8)
38
 
39
  # Advanced DiT Parameters
40
- use_adg: Use Adaptive Dual Guidance (base model only)
41
- cfg_interval_start: CFG application start ratio (0.0-1.0)
42
- cfg_interval_end: CFG application end ratio (0.0-1.0)
43
- audio_format: Output audio format ("mp3", "wav", "flac")
44
 
45
  # Task-Specific Parameters
46
- task_type: Generation task type ("text2music", "cover", "repaint", "lego", "extract", "complete")
47
- reference_audio: Path to reference audio file (for style transfer)
48
- src_audio: Path to source audio file (for audio-to-audio tasks)
49
- audio_code_string: Pre-extracted audio codes (advanced use)
50
- repainting_start: Repainting start time in seconds (for repaint/lego tasks)
51
- repainting_end: Repainting end time in seconds (-1 for end of audio)
52
- audio_cover_strength: Strength of audio cover/codes influence (0.0-1.0)
53
- instruction: Task-specific instruction prompt (auto-generated if empty)
54
 
55
- # 5Hz Language Model Parameters (CoT Reasoning)
56
- use_llm_thinking: Enable LM-based Chain-of-Thought reasoning for metadata/codes
57
- lm_temperature: LM sampling temperature (0.0-2.0, higher = more creative)
58
- lm_cfg_scale: LM classifier-free guidance scale
59
- lm_top_k: LM top-k sampling (0 = disabled)
60
- lm_top_p: LM nucleus sampling (1.0 = disabled)
61
- lm_negative_prompt: Negative prompt for LM guidance
62
- use_cot_metas: Generate metadata using LM CoT
63
- use_cot_caption: Refine caption using LM CoT
64
- use_cot_language: Detect language using LM CoT
65
- is_format_caption: Whether caption is already formatted
66
- constrained_decoding_debug: Enable debug logging for constrained decoding
67
-
68
- # Batch LM Generation
69
- allow_lm_batch: Allow batch LM code generation (faster for batch_size >= 2)
70
- lm_batch_chunk_size: Maximum batch size per LM inference chunk (GPU memory constraint)
71
  """
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Text Inputs
74
  caption: str = ""
75
  lyrics: str = ""
 
76
 
77
- # Music Metadata
78
- bpm: Optional[int] = None
79
- key_scale: str = ""
80
- time_signature: str = ""
81
  vocal_language: str = "unknown"
82
- audio_duration: Optional[float] = None
83
-
84
- # Generation Parameters
 
 
 
85
  inference_steps: int = 8
86
- guidance_scale: float = 7.0
87
- use_random_seed: bool = True
88
  seed: int = -1
89
- batch_size: int = 1
90
-
91
- # Advanced DiT Parameters
92
  use_adg: bool = False
93
  cfg_interval_start: float = 0.0
94
  cfg_interval_end: float = 1.0
95
- audio_format: str = "mp3"
96
-
97
- # Task-Specific Parameters
98
- task_type: str = "text2music"
99
- reference_audio: Optional[str] = None
100
- src_audio: Optional[str] = None
101
- audio_code_string: Union[str, List[str]] = ""
102
  repainting_start: float = 0.0
103
  repainting_end: float = -1
104
  audio_cover_strength: float = 1.0
105
- instruction: str = ""
106
 
107
  # 5Hz Language Model Parameters
108
- use_llm_thinking: bool = False
109
  lm_temperature: float = 0.85
110
  lm_cfg_scale: float = 2.0
111
  lm_top_k: int = 0
@@ -114,66 +109,59 @@ class GenerationConfig:
114
  use_cot_metas: bool = True
115
  use_cot_caption: bool = True
116
  use_cot_language: bool = True
117
- is_format_caption: bool = False
118
- constrained_decoding_debug: bool = False
119
 
120
- # Batch LM Generation
121
- allow_lm_batch: bool = False
122
- lm_batch_chunk_size: int = 4
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  @dataclass
126
  class GenerationResult:
127
  """Result of music generation.
128
 
129
  Attributes:
130
  # Audio Outputs
131
- audio_paths: List of paths to generated audio files
132
- first_audio: Path to first generated audio (backward compatibility)
133
- second_audio: Path to second generated audio (backward compatibility)
134
-
135
- # Generation Information
136
  generation_info: Markdown-formatted generation information
137
  status_message: Status message from generation
138
- seed_value: Actual seed value used for generation
139
-
140
- # LM-Generated Metadata (if applicable)
141
- lm_metadata: Metadata generated by language model (dict or None)
142
-
143
- # Audio-Text Alignment Scores (if available)
144
- align_score_1: First alignment score
145
- align_text_1: First alignment text description
146
- align_plot_1: First alignment plot image
147
- align_score_2: Second alignment score
148
- align_text_2: Second alignment text description
149
- align_plot_2: Second alignment plot image
150
-
151
- # Success Status
152
  success: Whether generation completed successfully
153
  error: Error message if generation failed
154
  """
155
 
156
  # Audio Outputs
157
- audio_paths: List[str] = field(default_factory=list)
158
- first_audio: Optional[str] = None
159
- second_audio: Optional[str] = None
160
-
161
  # Generation Information
162
  generation_info: str = ""
163
  status_message: str = ""
164
- seed_value: str = ""
165
-
166
- # LM-Generated Metadata
167
- lm_metadata: Optional[Dict[str, Any]] = None
168
-
169
- # Audio-Text Alignment Scores
170
- align_score_1: Optional[float] = None
171
- align_text_1: Optional[str] = None
172
- align_plot_1: Optional[Any] = None
173
- align_score_2: Optional[float] = None
174
- align_text_2: Optional[str] = None
175
- align_plot_2: Optional[Any] = None
176
-
177
  # Success Status
178
  success: bool = True
179
  error: Optional[str] = None
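# Illustrative sketch (not part of inference.py): the entry point keeps the same shape.
# Build a GenerationConfig, call generate_music, inspect the GenerationResult. Only fields
# that survive this refactor are used here; dit_handler and llm_handler are assumed to be
# already-initialized AceStepHandler / LLMHandler instances.
from acestep.inference import GenerationConfig, generate_music

config = GenerationConfig(
    caption="upbeat electronic dance music",
    lyrics="[Instrumental]",
    vocal_language="en",
    inference_steps=8,
    seed=42,
)
result = generate_music(dit_handler, llm_handler, config)
if result.success:
    print(result.generation_info)
else:
    print(result.error or result.status_message)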
@@ -186,75 +174,71 @@ class GenerationResult:
186
  def generate_music(
187
  dit_handler,
188
  llm_handler,
 
189
  config: GenerationConfig,
 
190
  ) -> GenerationResult:
191
  """Generate music using ACE-Step model with optional LM reasoning.
192
 
193
- This is the main inference API for music generation. It supports various task types
194
- (text2music, cover, repaint, etc.) and can optionally use a 5Hz Language Model for
195
- Chain-of-Thought reasoning to generate metadata and audio codes.
196
-
197
  Args:
198
  dit_handler: Initialized DiT model handler (AceStepHandler instance)
199
  llm_handler: Initialized LLM handler (LLMHandler instance)
 
200
  config: Generation configuration (GenerationConfig instance)
201
 
202
  Returns:
203
- GenerationResult: Generation result containing audio paths and metadata
204
-
205
- Example:
206
- >>> from acestep.handler import AceStepHandler
207
- >>> from acestep.llm_inference import LLMHandler
208
- >>> from acestep.inference import GenerationConfig, generate_music
209
- >>>
210
- >>> # Initialize handlers
211
- >>> dit_handler = AceStepHandler()
212
- >>> llm_handler = LLMHandler()
213
- >>> dit_handler.initialize_service(...)
214
- >>> llm_handler.initialize(...)
215
- >>>
216
- >>> # Configure generation
217
- >>> config = GenerationConfig(
218
- ... caption="upbeat electronic dance music",
219
- ... bpm=128,
220
- ... audio_duration=30,
221
- ... batch_size=2,
222
- ... )
223
- >>>
224
- >>> # Generate music
225
- >>> result = generate_music(dit_handler, llm_handler, config)
226
- >>> print(f"Generated {len(result.audio_paths)} audio files")
227
- >>> for path in result.audio_paths:
228
- ... print(f"Audio: {path}")
229
  """
230
-
231
  try:
232
  # Phase 1: LM-based metadata and code generation (if enabled)
233
- audio_code_string_to_use = config.audio_code_string
234
  lm_generated_metadata = None
235
- lm_generated_audio_codes = None
236
  lm_generated_audio_codes_list = []
237
 
238
  # Extract mutable copies of metadata (will be updated by LM if needed)
239
- bpm = config.bpm
240
- key_scale = config.key_scale
241
- time_signature = config.time_signature
242
- audio_duration = config.audio_duration
243
 
244
- # Determine if we should use batch LM generation
245
- should_use_lm_batch = (
246
- config.use_llm_thinking
247
- and llm_handler.llm_initialized
248
- and config.use_cot_metas
249
- and config.allow_lm_batch
250
- and config.batch_size >= 2
251
- )
 
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  # LM-based Chain-of-Thought reasoning
254
- if config.use_llm_thinking and llm_handler.llm_initialized and config.use_cot_metas:
255
  # Convert sampling parameters
256
- top_k_value = None if config.lm_top_k == 0 else int(config.lm_top_k)
257
- top_p_value = None if config.lm_top_p >= 1.0 else config.lm_top_p
258
 
259
  # Build user_metadata from user-provided values
260
  user_metadata = {}
@@ -286,165 +270,231 @@ def generate_music(
286
 
287
  user_metadata_to_pass = user_metadata if user_metadata else None
288
 
289
- # Batch LM generation (faster for multiple samples)
290
- if should_use_lm_batch:
291
- actual_seed_list, _ = dit_handler.prepare_seeds(
292
- config.batch_size, config.seed, config.use_random_seed
293
- )
294
-
295
- max_inference_batch_size = int(config.lm_batch_chunk_size)
296
- num_chunks = math.ceil(config.batch_size / max_inference_batch_size)
297
-
298
- all_metadata_list = []
299
- all_audio_codes_list = []
300
-
301
- for chunk_idx in range(num_chunks):
302
- chunk_start = chunk_idx * max_inference_batch_size
303
- chunk_end = min(chunk_start + max_inference_batch_size, config.batch_size)
304
- chunk_size = chunk_end - chunk_start
305
- chunk_seeds = actual_seed_list[chunk_start:chunk_end]
306
-
307
- logger.info(
308
- f"LM batch chunk {chunk_idx+1}/{num_chunks} "
309
- f"(size: {chunk_size}, seeds: {chunk_seeds})"
310
- )
311
-
312
- metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
313
- caption=config.caption or "",
314
- lyrics=config.lyrics or "",
315
- batch_size=chunk_size,
316
- infer_type="llm_dit",
317
- temperature=config.lm_temperature,
318
- cfg_scale=config.lm_cfg_scale,
319
- negative_prompt=config.lm_negative_prompt,
320
- top_k=top_k_value,
321
- top_p=top_p_value,
322
- user_metadata=user_metadata_to_pass,
323
- use_cot_caption=config.use_cot_caption,
324
- use_cot_language=config.use_cot_language,
325
- is_format_caption=config.is_format_caption,
326
- constrained_decoding_debug=config.constrained_decoding_debug,
327
- seeds=chunk_seeds,
328
- )
329
-
330
- all_metadata_list.extend(metadata_list)
331
- all_audio_codes_list.extend(audio_codes_list)
332
-
333
- lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
334
- lm_generated_audio_codes_list = all_audio_codes_list
335
- audio_code_string_to_use = all_audio_codes_list
336
 
337
- # Update metadata from LM if not provided by user
338
- if lm_generated_metadata:
339
- bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
340
- lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
341
- )
342
-
343
- else:
344
- # Sequential LM generation (current behavior)
345
- # Phase 1: Generate CoT metadata
346
- phase1_start = time_module.time()
347
- metadata, _, status = llm_handler.generate_with_stop_condition(
348
- caption=config.caption or "",
349
- lyrics=config.lyrics or "",
350
- infer_type="dit",
351
- temperature=config.lm_temperature,
352
- cfg_scale=config.lm_cfg_scale,
353
- negative_prompt=config.lm_negative_prompt,
354
- top_k=top_k_value,
355
- top_p=top_p_value,
356
- user_metadata=user_metadata_to_pass,
357
- use_cot_caption=config.use_cot_caption,
358
- use_cot_language=config.use_cot_language,
359
- is_format_caption=config.is_format_caption,
360
- constrained_decoding_debug=config.constrained_decoding_debug,
361
  )
362
- lm_phase1_time = time_module.time() - phase1_start
363
- logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
364
 
365
- # Phase 2: Generate audio codes
366
- phase2_start = time_module.time()
367
- metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
368
- caption=config.caption or "",
369
- lyrics=config.lyrics or "",
370
- infer_type="llm_dit",
371
- temperature=config.lm_temperature,
372
- cfg_scale=config.lm_cfg_scale,
373
- negative_prompt=config.lm_negative_prompt,
 
374
  top_k=top_k_value,
375
  top_p=top_p_value,
376
  user_metadata=user_metadata_to_pass,
377
- use_cot_caption=config.use_cot_caption,
378
- use_cot_language=config.use_cot_language,
379
  is_format_caption=config.is_format_caption,
 
380
  constrained_decoding_debug=config.constrained_decoding_debug,
 
 
381
  )
382
- lm_phase2_time = time_module.time() - phase2_start
383
- logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
384
 
385
- lm_generated_metadata = metadata
386
- if audio_codes:
387
- audio_code_string_to_use = audio_codes
388
- lm_generated_audio_codes = audio_codes
389
-
390
- # Update metadata from LM if not provided by user
391
- bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
392
- metadata, bpm, key_scale, time_signature, audio_duration
393
- )
394
 
395
  # Phase 2: DiT music generation
 
396
  result = dit_handler.generate_music(
397
- captions=config.caption,
398
- lyrics=config.lyrics,
399
  bpm=bpm,
400
  key_scale=key_scale,
401
  time_signature=time_signature,
402
- vocal_language=config.vocal_language,
403
- inference_steps=config.inference_steps,
404
- guidance_scale=config.guidance_scale,
405
  use_random_seed=config.use_random_seed,
406
- seed=config.seed,
407
- reference_audio=config.reference_audio,
408
  audio_duration=audio_duration,
409
- batch_size=config.batch_size,
410
- src_audio=config.src_audio,
411
  audio_code_string=audio_code_string_to_use,
412
- repainting_start=config.repainting_start,
413
- repainting_end=config.repainting_end,
414
- instruction=config.instruction,
415
- audio_cover_strength=config.audio_cover_strength,
416
- task_type=config.task_type,
417
- use_adg=config.use_adg,
418
- cfg_interval_start=config.cfg_interval_start,
419
- cfg_interval_end=config.cfg_interval_end,
420
- audio_format=config.audio_format,
421
- lm_temperature=config.lm_temperature,
422
  )
423
 
424
- # Extract results
425
- (first_audio, second_audio, all_audio_paths, generation_info, status_message,
426
- seed_value, align_score_1, align_text_1, align_plot_1,
427
- align_score_2, align_text_2, align_plot_2) = result
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
  # Append LM metadata to generation info
430
  if lm_generated_metadata:
431
  generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
432
 
433
- # Create result object
434
  return GenerationResult(
435
- audio_paths=all_audio_paths or [],
436
- first_audio=first_audio,
437
- second_audio=second_audio,
438
  generation_info=generation_info,
439
  status_message=status_message,
440
- seed_value=seed_value,
441
- lm_metadata=lm_generated_metadata,
442
- align_score_1=align_score_1,
443
- align_text_1=align_text_1,
444
- align_plot_1=align_plot_1,
445
- align_score_2=align_score_2,
446
- align_text_2=align_text_2,
447
- align_plot_2=align_plot_2,
448
  success=True,
449
  error=None,
450
  )
@@ -452,10 +502,12 @@ def generate_music(
452
  except Exception as e:
453
  logger.exception("Music generation failed")
454
  return GenerationResult(
455
- success=False,
456
- error=str(e),
457
  generation_info=f"❌ Generation failed: {str(e)}",
458
  status_message=f"Error: {str(e)}",
 
 
 
459
  )
460
 
461
 
@@ -525,7 +577,7 @@ def _append_lm_metadata_to_info(generation_info: str, metadata: Dict[str, Any])
525
  # LEGACY GRADIO UI COMPATIBILITY LAYER
526
  # ============================================================================
527
 
528
- def generate(
529
  dit_handler,
530
  llm_handler,
531
  captions,
@@ -575,20 +627,19 @@ def generate(
575
  Tuple with 28 elements for Gradio UI component updates
576
  """
577
 
578
- # Convert legacy parameters to new config
579
- config = GenerationConfig(
580
  caption=captions,
581
  lyrics=lyrics,
582
  bpm=bpm,
583
- key_scale=key_scale,
584
- time_signature=time_signature,
585
  vocal_language=vocal_language,
586
- audio_duration=audio_duration,
 
587
  inference_steps=inference_steps,
588
  guidance_scale=guidance_scale,
589
- use_random_seed=random_seed_checkbox,
590
  seed=seed,
591
- batch_size=batch_size_input,
592
  use_adg=use_adg,
593
  cfg_interval_start=cfg_interval_start,
594
  cfg_interval_end=cfg_interval_end,
@@ -596,12 +647,11 @@ def generate(
596
  task_type=task_type,
597
  reference_audio=reference_audio,
598
  src_audio=src_audio,
599
- audio_code_string=text2music_audio_code_string,
600
  repainting_start=repainting_start,
601
  repainting_end=repainting_end,
602
  audio_cover_strength=audio_cover_strength,
603
  instruction=instruction_display_gen,
604
- use_llm_thinking=think_checkbox,
605
  lm_temperature=lm_temperature,
606
  lm_cfg_scale=lm_cfg_scale,
607
  lm_top_k=lm_top_k,
@@ -610,29 +660,49 @@ def generate(
610
  use_cot_metas=use_cot_metas,
611
  use_cot_caption=use_cot_caption,
612
  use_cot_language=use_cot_language,
613
- is_format_caption=is_format_caption,
614
- constrained_decoding_debug=constrained_decoding_debug,
615
- allow_lm_batch=allow_lm_batch,
616
- lm_batch_chunk_size=lm_batch_chunk_size,
617
  )
618
 
 
 
 
 
 
 
 
 
619
  # Call new API
620
- result = generate_music(dit_handler, llm_handler, config)
621
 
622
  # Determine which codes to update in UI
623
- if config.allow_lm_batch and result.lm_metadata:
624
  # Batch mode: extract codes from metadata if available
625
- lm_codes_list = result.lm_metadata.get('audio_codes_list', [])
626
  updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
627
  codes_outputs = (lm_codes_list + [""] * 8)[:8]
628
  else:
629
  # Single mode
630
- lm_codes = result.lm_metadata.get('audio_codes', '') if result.lm_metadata else ''
631
  updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
632
  codes_outputs = [""] * 8
633
 
634
  # Prepare audio outputs (up to 8)
635
- audio_outputs = (result.audio_paths + [None] * 8)[:8]
636
 
637
  # Return tuple for Gradio UI (28 elements)
638
  return (
@@ -644,16 +714,16 @@ def generate(
644
  audio_outputs[5], # generated_audio_6
645
  audio_outputs[6], # generated_audio_7
646
  audio_outputs[7], # generated_audio_8
647
- result.audio_paths, # generated_audio_batch
648
  result.generation_info,
649
  result.status_message,
650
- result.seed_value,
651
- result.align_score_1,
652
- result.align_text_1,
653
- result.align_plot_1,
654
- result.align_score_2,
655
- result.align_text_2,
656
- result.align_plot_2,
657
  updated_audio_codes, # Update main audio codes in UI
658
  codes_outputs[0], # text2music_audio_code_string_1
659
  codes_outputs[1], # text2music_audio_code_string_2
@@ -663,266 +733,8 @@ def generate(
663
  codes_outputs[5], # text2music_audio_code_string_6
664
  codes_outputs[6], # text2music_audio_code_string_7
665
  codes_outputs[7], # text2music_audio_code_string_8
666
- result.lm_metadata, # Store metadata for "Send to src audio" buttons
667
  is_format_caption, # Keep is_format_caption unchanged
668
  )
669
 
670
 
671
- # ============================================================================
672
- # TESTING & EXAMPLES
673
- # ============================================================================
674
-
675
- if __name__ == "__main__":
676
- """
677
- Test suite for the inference API.
678
- Demonstrates various usage scenarios and validates functionality.
679
-
680
- Usage:
681
- python -m acestep.inference
682
- """
683
-
684
- import os
685
- import json
686
- from acestep.handler import AceStepHandler
687
- from acestep.llm_inference import LLMHandler
688
-
689
- # Initialize paths
690
- project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
691
- checkpoint_dir = os.path.join(project_root, "checkpoints")
692
-
693
- print("=" * 80)
694
- print("ACE-Step Inference API Test Suite")
695
- print("=" * 80)
696
-
697
- # ========================================================================
698
- # Initialize Handlers
699
- # ========================================================================
700
- print("\n[1/3] Initializing handlers...")
701
- dit_handler = AceStepHandler(save_root="./")
702
- llm_handler = LLMHandler()
703
-
704
- try:
705
- # Initialize DiT handler
706
- print(" - Initializing DiT model...")
707
- status_dit, success_dit = dit_handler.initialize_service(
708
- project_root=project_root,
709
- config_path="acestep-v15-turbo-rl",
710
- device="cuda",
711
- )
712
- if not success_dit:
713
- print(f" ❌ DiT initialization failed: {status_dit}")
714
- exit(1)
715
- print(f" ✓ DiT model initialized successfully")
716
-
717
- # Initialize LLM handler
718
- print(" - Initializing 5Hz LM model...")
719
- status_llm, success_llm = llm_handler.initialize(
720
- checkpoint_dir=checkpoint_dir,
721
- lm_model_path="acestep-5Hz-lm-0.6B-v3",
722
- backend="vllm",
723
- device="cuda",
724
- )
725
- if success_llm:
726
- print(f" ✓ LM model initialized successfully")
727
- else:
728
- print(f" ⚠ LM initialization failed (will skip LM tests): {status_llm}")
729
-
730
- except Exception as e:
731
- print(f" ❌ Initialization error: {e}")
732
- exit(1)
733
-
734
- # ========================================================================
735
- # Helper Functions
736
- # ========================================================================
737
- def load_example_config(example_file: str) -> GenerationConfig:
738
- """Load configuration from an example JSON file."""
739
- try:
740
- with open(example_file, 'r', encoding='utf-8') as f:
741
- data = json.load(f)
742
-
743
- # Convert example format to GenerationConfig
744
- # Handle time signature format (example uses "4" instead of "4/4")
745
- time_sig = data.get('timesignature', '')
746
- if time_sig and '/' not in time_sig:
747
- time_sig = f"{time_sig}/4" # Default to /4 if only numerator given
748
-
749
- config = GenerationConfig(
750
- caption=data.get('caption', ''),
751
- lyrics=data.get('lyrics', ''),
752
- bpm=data.get('bpm'),
753
- key_scale=data.get('keyscale', ''),
754
- time_signature=time_sig,
755
- vocal_language=data.get('language', 'unknown'),
756
- audio_duration=data.get('duration'),
757
- use_llm_thinking=data.get('think', False),
758
- batch_size=data.get('batch_size', 1),
759
- inference_steps=data.get('inference_steps', 8),
760
- )
761
- return config
762
-
763
- except Exception as e:
764
- print(f" ⚠ Failed to load example file: {e}")
765
- return None
766
-
767
- # ========================================================================
768
- # Test Cases
769
- # ========================================================================
770
- test_results = []
771
-
772
- def run_test(test_name: str, config: GenerationConfig, expected_outputs: int = 1):
773
- """Run a single test case and collect results."""
774
- print(f"\n{'=' * 80}")
775
- print(f"Test: {test_name}")
776
- print(f"{'=' * 80}")
777
-
778
- # Display configuration
779
- print("\nConfiguration:")
780
- print(f" Task Type: {config.task_type}")
781
- print(f" Caption: {config.caption[:60]}..." if len(config.caption) > 60 else f" Caption: {config.caption}")
782
- if config.lyrics:
783
- print(f" Lyrics: {config.lyrics[:60]}..." if len(config.lyrics) > 60 else f" Lyrics: {config.lyrics}")
784
- if config.bpm:
785
- print(f" BPM: {config.bpm}")
786
- if config.key_scale:
787
- print(f" Key Scale: {config.key_scale}")
788
- if config.time_signature:
789
- print(f" Time Signature: {config.time_signature}")
790
- if config.audio_duration:
791
- print(f" Duration: {config.audio_duration}s")
792
- print(f" Batch Size: {config.batch_size}")
793
- print(f" Inference Steps: {config.inference_steps}")
794
- print(f" Use LLM Thinking: {config.use_llm_thinking}")
795
-
796
- # Run generation
797
- print("\nGenerating...")
798
- import time
799
- start_time = time.time()
800
-
801
- result = generate_music(dit_handler, llm_handler, config)
802
-
803
- elapsed_time = time.time() - start_time
804
-
805
- # Display results
806
- print("\nResults:")
807
- print(f" Success: {'✓' if result.success else '✗'}")
808
-
809
- if result.success:
810
- print(f" Generated Files: {len(result.audio_paths)}")
811
- for i, path in enumerate(result.audio_paths, 1):
812
- if os.path.exists(path):
813
- file_size = os.path.getsize(path) / (1024 * 1024) # MB
814
- print(f" [{i}] {os.path.basename(path)} ({file_size:.2f} MB)")
815
- else:
816
- print(f" [{i}] {os.path.basename(path)} (file not found)")
817
-
818
- print(f" Seed: {result.seed_value}")
819
- print(f" Generation Time: {elapsed_time:.2f}s")
820
-
821
- # Display LM metadata if available
822
- if result.lm_metadata:
823
- print(f"\n LM-Generated Metadata:")
824
- for key, value in result.lm_metadata.items():
825
- if key not in ['audio_codes', 'audio_codes_list']: # Skip large code strings
826
- print(f" {key}: {value}")
827
-
828
- # Validate outputs
829
- if len(result.audio_paths) != expected_outputs:
830
- print(f" ⚠ Warning: Expected {expected_outputs} outputs, got {len(result.audio_paths)}")
831
- success = False
832
- else:
833
- success = True
834
-
835
- else:
836
- print(f" Error: {result.error}")
837
- success = False
838
-
839
- # Store test result
840
- test_results.append({
841
- "test_name": test_name,
842
- "success": success,
843
- "generation_success": result.success,
844
- "num_outputs": len(result.audio_paths) if result.success else 0,
845
- "expected_outputs": expected_outputs,
846
- "elapsed_time": elapsed_time,
847
- "error": result.error if not result.success else None,
848
- })
849
-
850
- return result
851
-
852
- # ========================================================================
853
- # Test: Production Example (from examples directory)
854
- # ========================================================================
855
- print("\n[2/3] Running Test...")
856
-
857
- # Load production example (J-Rock song from examples/text2music/example_05.json)
858
- example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
859
-
860
- if not os.path.exists(example_file):
861
- print(f"\n ❌ Example file not found: {example_file}")
862
- print(" Please ensure the examples directory exists.")
863
- exit(1)
864
-
865
- print(f" Loading example: {os.path.basename(example_file)}")
866
- config = load_example_config(example_file)
867
-
868
- if not config:
869
- print(" ❌ Failed to load example configuration")
870
- exit(1)
871
-
872
- # Reduce duration for faster testing (original is 200s)
873
- print(f" Original duration: {config.audio_duration}s")
874
- config.audio_duration = 30
875
- config.use_random_seed = False
876
- config.seed = 42
877
- print(f" Test duration: {config.audio_duration}s (reduced for testing)")
878
-
879
- run_test("Production Example (J-Rock Song)", config, expected_outputs=1)
880
-
881
- # ========================================================================
882
- # Test Summary
883
- # ========================================================================
884
- print("\n[3/3] Test Summary")
885
- print("=" * 80)
886
-
887
- if len(test_results) == 0:
888
- print("No tests were run.")
889
- exit(1)
890
-
891
- result = test_results[0]
892
-
893
- print(f"\nTest: {result['test_name']}")
894
- print(f"Status: {'✓ PASS' if result['success'] else '✗ FAIL'}")
895
- print(f"Generation: {'Success' if result['generation_success'] else 'Failed'}")
896
- print(f"Outputs: {result['num_outputs']}/{result['expected_outputs']}")
897
- print(f"Time: {result['elapsed_time']:.2f}s")
898
-
899
- if result["error"]:
900
- print(f"Error: {result['error']}")
901
-
902
- # Save test results to JSON
903
- results_file = os.path.join(project_root, "test_results.json")
904
- try:
905
- with open(results_file, "w") as f:
906
- json.dump({
907
- "test_name": result['test_name'],
908
- "success": result['success'],
909
- "generation_success": result['generation_success'],
910
- "num_outputs": result['num_outputs'],
911
- "expected_outputs": result['expected_outputs'],
912
- "elapsed_time": result['elapsed_time'],
913
- "error": result['error'],
914
- }, f, indent=2)
915
- print(f"\n✓ Test results saved to: {results_file}")
916
- except Exception as e:
917
- print(f"\n⚠ Failed to save test results: {e}")
918
-
919
- # Exit with appropriate code
920
- print("\n" + "=" * 80)
921
- if result['success']:
922
- print("Test passed! ✓")
923
- print("=" * 80)
924
- exit(0)
925
- else:
926
- print("Test failed! ✗")
927
- print("=" * 80)
928
- exit(1)
 
7
  """
8
 
9
  import math
10
+ import os
11
+ import tempfile
12
  from typing import Optional, Union, List, Dict, Any, Tuple
13
  from dataclasses import dataclass, field, asdict
14
  from loguru import logger
15
+
16
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
 
18
 
19
  @dataclass
20
+ class GenerationParams:
21
+ """Configuration for music generation parameters.
22
 
23
  Attributes:
24
  # Text Inputs
25
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
26
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
27
+ instrumental: If True, generate instrumental music regardless of lyrics.
28
 
29
  # Music Metadata
30
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
31
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
32
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
33
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
34
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
35
 
36
  # Generation Parameters
37
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32~100 for base model).
38
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only supported for non-turbo models.
39
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
 
 
40
 
41
  # Advanced DiT Parameters
42
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
43
+ cfg_interval_start: Start ratio (0.0~1.0) to apply CFG.
44
+ cfg_interval_end: End ratio (0.0~1.0) to apply CFG.
 
45
 
46
  # Task-Specific Parameters
47
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
48
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
49
+ src_audio: Path to a source audio file for audio-to-audio tasks.
50
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
51
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
52
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
53
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0~1.0). Set a smaller value (e.g., 0.2) for style transfer tasks.
54
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
55
 
56
+ # 5Hz Language Model Parameters for CoT reasoning
57
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
58
+ lm_temperature: Sampling temperature for the LLM (0.0~2.0). Higher = more creative/varied results.
59
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
60
+ lm_top_k: LLM top-k sampling (0 = disabled).
61
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
62
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
63
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
64
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
65
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
 
 
 
 
 
 
66
  """
67
+ # Required Inputs
68
+ task_type: str = "text2music"
69
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
70
+
71
+ # Audio Uploads
72
+ reference_audio: Optional[str] = None
73
+ src_audio: Optional[str] = None
74
+
75
+ # LM Codes Hints
76
+ audio_codes: str = ""
77
 
78
  # Text Inputs
79
  caption: str = ""
80
  lyrics: str = ""
81
+ instrumental: bool = False
82
 
83
+ # Metadata
 
 
 
84
  vocal_language: str = "unknown"
85
+ bpm: Optional[int] = None
86
+ keyscale: str = ""
87
+ timesignature: str = ""
88
+ duration: float = -1.0
89
+
90
+ # Advanced Settings
91
  inference_steps: int = 8
 
 
92
  seed: int = -1
93
+ guidance_scale: float = 7.0
 
 
94
  use_adg: bool = False
95
  cfg_interval_start: float = 0.0
96
  cfg_interval_end: float = 1.0
97
+
 
 
 
 
 
 
98
  repainting_start: float = 0.0
99
  repainting_end: float = -1
100
  audio_cover_strength: float = 1.0
 
101
 
102
  # 5Hz Language Model Parameters
103
+ thinking: bool = True
104
  lm_temperature: float = 0.85
105
  lm_cfg_scale: float = 2.0
106
  lm_top_k: int = 0
 
109
  use_cot_metas: bool = True
110
  use_cot_caption: bool = True
111
  use_cot_language: bool = True
 
 
112
 
113
+ def to_dict(self) -> Dict[str, Any]:
114
+ """Convert config to dictionary for JSON serialization."""
115
+ return asdict(self)
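When the metadata fields are left at their defaults (bpm=None, keyscale="", timesignature="", duration=-1.0) and thinking/use_cot_metas are enabled, the 5Hz LM estimates them via CoT reasoning. A minimal illustrative construction (prompt text is made up):

params = GenerationParams(
    caption="dreamy shoegaze with washed-out guitars",
    lyrics="[Instrumental]",
    # bpm, keyscale, timesignature, duration left at defaults -> estimated by the LM CoT
)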
116
 
117
 
118
+ @dataclass
119
+ class GenerationConfig:
120
+ """Configuration for music generation.
121
+
122
+ Attributes:
123
+ batch_size: Number of audio samples to generate
124
+ allow_lm_batch: Whether to allow batch processing in LM
125
+ use_random_seed: Whether to use random seed
126
+ seed: Seed(s) for batch generation. Can be:
127
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
128
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
129
+ - int: Single seed value (will be converted to list and padded)
130
+ lm_batch_chunk_size: Batch chunk size for LM processing
131
+ is_format_caption: Whether to format caption
132
+ constrained_decoding_debug: Whether to enable constrained decoding debug
133
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
134
+ """
135
+ batch_size: int = 2
136
+ allow_lm_batch: bool = False
137
+ use_random_seed: bool = True
138
+ seed: Optional[Union[int, List[int]]] = None
139
+ lm_batch_chunk_size: int = 8
140
+ is_format_caption: bool = False
141
+ use_constrained_decoding: bool = True
142
+ constrained_decoding_debug: bool = False
143
+ audio_format: str = "flac" # Default to FLAC for fast saving
144
+
145
  @dataclass
146
  class GenerationResult:
147
  """Result of music generation.
148
 
149
  Attributes:
150
  # Audio Outputs
151
+ audios: List of audio dicts, each with "path", "tensor", "key", "sample_rate", and "params"
 
 
 
 
152
  generation_info: Markdown-formatted generation information
153
  status_message: Status message from generation
154
+ extra_outputs: Extra outputs from generation (DiT extras such as latents and masks, plus "lm_metadata")
155
  success: Whether generation completed successfully
156
  error: Error message if generation failed
157
  """
158
 
159
  # Audio Outputs
160
+ audios: List[Dict[str, Any]] = field(default_factory=list)
 
 
 
161
  # Generation Information
162
  generation_info: str = ""
163
  status_message: str = ""
164
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
165
  # Success Status
166
  success: bool = True
167
  error: Optional[str] = None
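Each entry of audios is a plain dict; a sketch of the shape callers can rely on (field names follow the construction code later in generate_music):

audio = result.audios[0]
audio["path"]         # saved file path; "" when save_dir is None or saving failed
audio["tensor"]       # torch.Tensor [channels, samples], CPU, float32
audio["key"]          # deterministic UUID derived from the generation parameters
audio["sample_rate"]  # e.g. 48000
audio["params"]       # per-sample copy of GenerationParams.to_dict(), with seed and audio_codes filled in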
 
174
  def generate_music(
175
  dit_handler,
176
  llm_handler,
177
+ params: GenerationParams,
178
  config: GenerationConfig,
179
+ save_dir: Optional[str] = None,
180
  ) -> GenerationResult:
181
  """Generate music using ACE-Step model with optional LM reasoning.
182
 
 
 
 
 
183
  Args:
184
  dit_handler: Initialized DiT model handler (AceStepHandler instance)
185
  llm_handler: Initialized LLM handler (LLMHandler instance)
186
+ params: Generation parameters (GenerationParams instance)
187
  config: Generation configuration (GenerationConfig instance)
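+ save_dir: Optional directory where generated audio files are saved (created if needed); if None, no files are written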
188
 
189
  Returns:
190
+ GenerationResult with generated audio files and metadata
191
  """
 
192
  try:
193
  # Phase 1: LM-based metadata and code generation (if enabled)
194
+ audio_code_string_to_use = params.audio_codes
195
  lm_generated_metadata = None
 
196
  lm_generated_audio_codes_list = []
197
 
198
  # Extract mutable copies of metadata (will be updated by LM if needed)
199
+ bpm = params.bpm
200
+ key_scale = params.keyscale
201
+ time_signature = params.timesignature
202
+ audio_duration = params.duration
203
 
204
+ # Determine if we need to generate audio codes
205
+ # If user has provided audio_codes, we don't need to generate them
206
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
207
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
208
+
209
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
210
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
211
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
212
+ # Note: This logic can be refined based on specific requirements
213
+ need_audio_codes = not user_provided_audio_codes
214
 
215
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
216
+ # Determine actual batch size for chunk processing
217
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
218
+
219
+ # Prepare seeds for batch generation
220
+ # Use config.seed if provided, otherwise fallback to params.seed
221
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
222
+ seed_for_generation = params.seed # Default fallback
223
+ if config.seed is not None:
224
+ if isinstance(config.seed, list):
225
+ # Convert List[int] to comma-separated string
226
+ seed_for_generation = ",".join(str(s) for s in config.seed)
227
+ elif isinstance(config.seed, int):
228
+ # Single int seed
229
+ seed_for_generation = config.seed
230
+
231
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
232
+ # This will handle all the logic: padding with random seeds if needed, etc.
233
+ actual_seed_list, _ = dit_handler.prepare_seeds(
234
+ actual_batch_size, seed_for_generation, config.use_random_seed
235
+ )
236
+
237
  # LM-based Chain-of-Thought reasoning
238
+ if params.thinking and llm_handler.llm_initialized and params.use_cot_metas:
239
  # Convert sampling parameters
240
+ top_k_value = None if params.lm_top_k == 0 else int(params.lm_top_k)
241
+ top_p_value = None if params.lm_top_p >= 1.0 else params.lm_top_p
242
 
243
  # Build user_metadata from user-provided values
244
  user_metadata = {}
 
270
 
271
  user_metadata_to_pass = user_metadata if user_metadata else None
272
 
273
+ # Determine infer_type based on whether we need audio codes
274
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
275
+ # - "dit": generates only metas (single phase)
276
+ infer_type = "llm_dit" if need_audio_codes else "dit"
277
+
278
+ # Use chunk size from config, or default to batch_size if not set
279
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
280
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
281
+
282
+ all_metadata_list = []
283
+ all_audio_codes_list = []
284
+
285
+ for chunk_idx in range(num_chunks):
286
+ chunk_start = chunk_idx * max_inference_batch_size
287
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
288
+ chunk_size = chunk_end - chunk_start
289
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
290
 
291
+ logger.info(
292
+ f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
293
+ f"(size: {chunk_size}, seeds: {chunk_seeds})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  )
 
 
295
 
296
+ # Use the determined infer_type
297
+ # - "llm_dit" will internally run two phases (metas + codes)
298
+ # - "dit" will only run phase 1 (metas only)
299
+ result = llm_handler.generate_with_stop_condition(
300
+ caption=params.caption or "",
301
+ lyrics=params.lyrics or "",
302
+ infer_type=infer_type,
303
+ temperature=params.lm_temperature,
304
+ cfg_scale=params.lm_cfg_scale,
305
+ negative_prompt=params.lm_negative_prompt,
306
  top_k=top_k_value,
307
  top_p=top_p_value,
308
  user_metadata=user_metadata_to_pass,
309
+ use_cot_caption=params.use_cot_caption,
310
+ use_cot_language=params.use_cot_language,
311
  is_format_caption=config.is_format_caption,
312
+ use_constrained_decoding=config.use_constrained_decoding,
313
  constrained_decoding_debug=config.constrained_decoding_debug,
314
+ batch_size=chunk_size,
315
+ seeds=chunk_seeds,
316
  )
 
 
317
 
318
+ if chunk_size > 1:
319
+ metadata_list, audio_codes_list, status = result
320
+ all_metadata_list.extend(metadata_list)
321
+ all_audio_codes_list.extend(audio_codes_list)
322
+ else:
323
+ metadata, audio_codes, status = result
324
+ all_metadata_list.append(metadata)
325
+ all_audio_codes_list.append(audio_codes)
326
+
327
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
328
+ lm_generated_audio_codes_list = all_audio_codes_list
329
+
330
+ # Set audio_code_string_to_use based on infer_type
331
+ if infer_type == "llm_dit":
332
+ # If batch mode, use list; otherwise use single string
333
+ if actual_batch_size > 1:
334
+ audio_code_string_to_use = all_audio_codes_list
335
+ else:
336
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
337
+ else:
338
+ # For "dit" mode, keep user-provided codes or empty
339
+ audio_code_string_to_use = params.audio_codes
340
+
341
+ # Update metadata from LM if not provided by user
342
+ if lm_generated_metadata:
343
+ bpm, key_scale, time_signature, audio_duration = _update_metadata_from_lm(
344
+ lm_generated_metadata, bpm, key_scale, time_signature, audio_duration
345
+ )
346
+
347
 
348
  # Phase 2: DiT music generation
349
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
350
  result = dit_handler.generate_music(
351
+ captions=params.caption,
352
+ lyrics=params.lyrics,
353
  bpm=bpm,
354
  key_scale=key_scale,
355
  time_signature=time_signature,
356
+ vocal_language=params.vocal_language,
357
+ inference_steps=params.inference_steps,
358
+ guidance_scale=params.guidance_scale,
359
  use_random_seed=config.use_random_seed,
360
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
361
+ reference_audio=params.reference_audio,
362
  audio_duration=audio_duration,
363
+ batch_size=config.batch_size if config.batch_size is not None else 1,
364
+ src_audio=params.src_audio,
365
  audio_code_string=audio_code_string_to_use,
366
+ repainting_start=params.repainting_start,
367
+ repainting_end=params.repainting_end,
368
+ instruction=params.instruction,
369
+ audio_cover_strength=params.audio_cover_strength,
370
+ task_type=params.task_type,
371
+ use_adg=params.use_adg,
372
+ cfg_interval_start=params.cfg_interval_start,
373
+ cfg_interval_end=params.cfg_interval_end,
 
 
374
  )
375
 
376
+ # Check if generation failed
377
+ if not result.get("success", False):
378
+ return GenerationResult(
379
+ audios=[],
380
+ generation_info=result.get("generation_info", ""),
381
+ status_message=result.get("status_message", ""),
382
+ extra_outputs={},
383
+ success=False,
384
+ error=result.get("error"),
385
+ )
386
+
387
+ # Extract results from dit_handler.generate_music dict
388
+ dit_audios = result.get("audios", [])
389
+ generation_info = result.get("generation_info", "")
390
+ status_message = result.get("status_message", "")
391
+ dit_extra_outputs = result.get("extra_outputs", {})
392
 
393
  # Append LM metadata to generation info
394
  if lm_generated_metadata:
395
  generation_info = _append_lm_metadata_to_info(generation_info, lm_generated_metadata)
396
 
397
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
398
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
399
+ seed_list = actual_seed_list
400
+
401
+ # Get base params dictionary
402
+ base_params_dict = params.to_dict()
403
+
404
+ # Save audio files using AudioSaver (format from config)
405
+ audio_format = config.audio_format if config.audio_format else "flac"
406
+ audio_saver = AudioSaver(default_format=audio_format)
407
+
408
+ # Use handler's temp_dir for saving files
409
+ if save_dir is not None:
410
+ os.makedirs(save_dir, exist_ok=True)
411
+
412
+ # Build audios list for GenerationResult with params and save files
413
+ # Audio saving and UUID generation handled here, outside of handler
414
+ audios = []
415
+ for idx, dit_audio in enumerate(dit_audios):
416
+ # Create a copy of params dict for this audio
417
+ audio_params = base_params_dict.copy()
418
+
419
+ # Update audio-specific values
420
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
421
+
422
+ # Add audio codes if batch mode
423
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
424
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
425
+
426
+ # Get audio tensor and metadata
427
+ audio_tensor = dit_audio.get("tensor")
428
+ sample_rate = dit_audio.get("sample_rate", 48000)
429
+
430
+ # Generate UUID for this audio (moved from handler)
431
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
432
+ audio_code_str = lm_generated_audio_codes_list[idx] if (lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
433
+ if isinstance(audio_code_str, list):
434
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
435
+
436
+ audio_key = generate_uuid_from_params(
437
+ captions=params.caption,
438
+ lyrics=params.lyrics,
439
+ bpm=bpm,
440
+ key_scale=key_scale,
441
+ time_signature=time_signature,
442
+ vocal_language=params.vocal_language,
443
+ inference_steps=params.inference_steps,
444
+ guidance_scale=params.guidance_scale,
445
+ seed=batch_seed,
446
+ audio_duration=audio_duration,
447
+ audio_code_string=audio_code_str,
448
+ repainting_start=params.repainting_start,
449
+ repainting_end=params.repainting_end,
450
+ instruction=params.instruction,
451
+ audio_cover_strength=params.audio_cover_strength,
452
+ task_type=params.task_type,
453
+ use_adg=params.use_adg,
454
+ cfg_interval_start=params.cfg_interval_start,
455
+ cfg_interval_end=params.cfg_interval_end,
456
+ audio_format=audio_format,
457
+ reference_audio=params.reference_audio,
458
+ src_audio=params.src_audio,
459
+ batch_index=idx,
460
+ )
461
+
462
+ # Save audio file (handled outside handler)
463
+ audio_path = None
464
+ if audio_tensor is not None and save_dir is not None:
465
+ try:
466
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
467
+ audio_path = audio_saver.save_audio(
468
+ audio_tensor,
469
+ audio_file,
470
+ sample_rate=sample_rate,
471
+ format=audio_format,
472
+ channels_first=True
473
+ )
474
+ except Exception as e:
475
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
476
+ audio_path = "" # Fallback to empty path
477
+
478
+ audio_dict = {
479
+ "path": audio_path or "", # File path (saved here, not in handler)
480
+ "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
481
+ "key": audio_key,
482
+ "sample_rate": sample_rate,
483
+ "params": audio_params,
484
+ }
485
+
486
+ audios.append(audio_dict)
487
+
488
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
489
+ extra_outputs = dit_extra_outputs.copy()
490
+ extra_outputs["lm_metadata"] = lm_generated_metadata
491
+
492
+ # Create and return GenerationResult
493
  return GenerationResult(
494
+ audios=audios,
 
 
495
  generation_info=generation_info,
496
  status_message=status_message,
497
+ extra_outputs=extra_outputs,
 
 
 
 
 
 
 
498
  success=True,
499
  error=None,
500
  )
 
502
  except Exception as e:
503
  logger.exception("Music generation failed")
504
  return GenerationResult(
505
+ audios=[],
 
506
  generation_info=f"❌ Generation failed: {str(e)}",
507
  status_message=f"Error: {str(e)}",
508
+ extra_outputs={},
509
+ success=False,
510
+ error=str(e),
511
  )
512
 
513
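For reference, a minimal end-to-end sketch of the new API. It assumes dit_handler and llm_handler are already initialized and that this module is importable as acestep.inference (as in the removed test suite above); prompt, seed, and path values are illustrative:

from acestep.inference import GenerationParams, GenerationConfig, generate_music

params = GenerationParams(
    caption="uplifting synthpop with airy female vocals",
    lyrics="[Instrumental]",
    duration=30,
)
config = GenerationConfig(batch_size=1, use_random_seed=False, seed=42, audio_format="flac")
result = generate_music(dit_handler, llm_handler, params, config, save_dir="./outputs")
if result.success:
    for audio in result.audios:
        print(audio["key"], audio["path"], audio["sample_rate"])
else:
    print(result.error)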
 
 
577
  # LEGACY GRADIO UI COMPATIBILITY LAYER
578
  # ============================================================================
579
 
580
+ def generate_for_gradio(
581
  dit_handler,
582
  llm_handler,
583
  captions,
 
627
  Tuple with 28 elements for Gradio UI component updates
628
  """
629
 
630
+ # Convert legacy parameters to GenerationParams and GenerationConfig
631
+ params = GenerationParams(
632
  caption=captions,
633
  lyrics=lyrics,
634
  bpm=bpm,
635
+ keyscale=key_scale,
636
+ timesignature=time_signature,
637
  vocal_language=vocal_language,
638
+ audio_codes=text2music_audio_code_string,
639
+ duration=audio_duration,
640
  inference_steps=inference_steps,
641
  guidance_scale=guidance_scale,
 
642
  seed=seed,
 
643
  use_adg=use_adg,
644
  cfg_interval_start=cfg_interval_start,
645
  cfg_interval_end=cfg_interval_end,
 
647
  task_type=task_type,
648
  reference_audio=reference_audio,
649
  src_audio=src_audio,
 
650
  repainting_start=repainting_start,
651
  repainting_end=repainting_end,
652
  audio_cover_strength=audio_cover_strength,
653
  instruction=instruction_display_gen,
654
+ thinking=think_checkbox,
655
  lm_temperature=lm_temperature,
656
  lm_cfg_scale=lm_cfg_scale,
657
  lm_top_k=lm_top_k,
 
660
  use_cot_metas=use_cot_metas,
661
  use_cot_caption=use_cot_caption,
662
  use_cot_language=use_cot_language,
 
 
 
 
663
  )
664
 
665
+ config = GenerationConfig(batch_size=1)
666
+ config.batch_size = batch_size_input
667
+ config.use_random_seed = random_seed_checkbox
668
+ config.allow_lm_batch = allow_lm_batch
669
+ config.lm_batch_chunk_size = lm_batch_chunk_size
670
+ config.is_format_caption = is_format_caption
671
+ config.constrained_decoding_debug = constrained_decoding_debug
672
+
673
  # Call new API
674
+ result = generate_music(dit_handler, llm_handler, params, config)
675
+
676
+ # Extract audio paths from result.audios
677
+ audio_paths = [audio["path"] for audio in result.audios]
678
+
679
+ # Extract extra outputs
680
+ extra_outputs = result.extra_outputs
681
+ seed_value = extra_outputs.get("seed_value", "")
682
+ lm_metadata = extra_outputs.get("lm_metadata", None)
683
+
684
+ # Legacy alignment fields (no longer used, set to empty/None)
685
+ align_score_1 = ""
686
+ align_text_1 = ""
687
+ align_plot_1 = None
688
+ align_score_2 = ""
689
+ align_text_2 = ""
690
+ align_plot_2 = None
691
 
692
  # Determine which codes to update in UI
693
+ if config.allow_lm_batch and lm_metadata:
694
  # Batch mode: extract codes from metadata if available
695
+ lm_codes_list = lm_metadata.get('audio_codes_list', [])
696
  updated_audio_codes = lm_codes_list[0] if lm_codes_list else text2music_audio_code_string
697
  codes_outputs = (lm_codes_list + [""] * 8)[:8]
698
  else:
699
  # Single mode
700
+ lm_codes = lm_metadata.get('audio_codes', '') if lm_metadata else ''
701
  updated_audio_codes = lm_codes if lm_codes else text2music_audio_code_string
702
  codes_outputs = [""] * 8
703
 
704
  # Prepare audio outputs (up to 8)
705
+ audio_outputs = (audio_paths + [None] * 8)[:8]
706
 
707
  # Return tuple for Gradio UI (28 elements)
708
  return (
 
714
  audio_outputs[5], # generated_audio_6
715
  audio_outputs[6], # generated_audio_7
716
  audio_outputs[7], # generated_audio_8
717
+ audio_paths, # generated_audio_batch
718
  result.generation_info,
719
  result.status_message,
720
+ seed_value,
721
+ align_score_1,
722
+ align_text_1,
723
+ align_plot_1,
724
+ align_score_2,
725
+ align_text_2,
726
+ align_plot_2,
727
  updated_audio_codes, # Update main audio codes in UI
728
  codes_outputs[0], # text2music_audio_code_string_1
729
  codes_outputs[1], # text2music_audio_code_string_2
 
733
  codes_outputs[5], # text2music_audio_code_string_6
734
  codes_outputs[6], # text2music_audio_code_string_7
735
  codes_outputs[7], # text2music_audio_code_string_8
736
+ lm_metadata, # Store metadata for "Send to src audio" buttons
737
  is_format_caption, # Keep is_format_caption unchanged
738
  )
739
 
740
acestep/llm_inference.py CHANGED
@@ -5,7 +5,7 @@ Handles all LM-related operations including initialization and generation
5
  import os
6
  import traceback
7
  import time
8
- from typing import Optional, Dict, Any, Tuple, List
9
  from contextlib import contextmanager
10
 
11
  import yaml
@@ -85,6 +85,189 @@ class LLMHandler:
85
  except Exception as e:
86
  return 0.9, False
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def initialize(
89
  self,
90
  checkpoint_dir: str,
@@ -150,41 +333,21 @@ class LLMHandler:
150
  # vllm initialization failed, fallback to PyTorch
151
  if not self.llm_initialized:
152
  logger.warning("vllm initialization failed, falling back to PyTorch backend")
153
- try:
154
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
155
- if not self.offload_to_cpu:
156
- self.llm = self.llm.to(device).to(self.dtype)
157
- else:
158
- self.llm = self.llm.to("cpu").to(self.dtype)
159
- self.llm.eval()
160
- self.llm_backend = "pt"
161
- self.llm_initialized = True
162
- logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
163
- status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
164
- except Exception as e:
165
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
166
  # If vllm initialization succeeded, self.llm_initialized should already be True
167
  else:
168
  # Use PyTorch backend (pt)
169
- try:
170
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
171
- if not self.offload_to_cpu:
172
- self.llm = self.llm.to(device).to(self.dtype)
173
- else:
174
- self.llm = self.llm.to("cpu").to(self.dtype)
175
- self.llm.eval()
176
- self.llm_backend = "pt"
177
- self.llm_initialized = True
178
- logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
179
- status_msg = f"✅ 5Hz LM initialized successfully\nModel: {full_lm_model_path}\nBackend: PyTorch\nDevice: {device}"
180
- except Exception as e:
181
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
182
 
183
  return status_msg, True
184
 
185
  except Exception as e:
186
- error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
187
- return error_msg, False
188
 
189
  def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
190
  """Initialize 5Hz LM model using vllm backend"""
@@ -230,12 +393,11 @@ class LLMHandler:
230
  return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
231
  except Exception as e:
232
  self.llm_initialized = False
233
- error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
234
- return error_msg
235
 
236
- def _run_vllm_from_formatted(
237
  self,
238
- formatted_prompt: str,
239
  temperature: float,
240
  cfg_scale: float,
241
  negative_prompt: str,
@@ -244,7 +406,7 @@ class LLMHandler:
244
  repetition_penalty: float,
245
  use_constrained_decoding: bool = True,
246
  constrained_decoding_debug: bool = False,
247
- metadata_temperature: Optional[float] = 0.85,
248
  codes_temperature: Optional[float] = None,
249
  target_duration: Optional[float] = None,
250
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
@@ -256,37 +418,40 @@ class LLMHandler:
256
  caption: str = "",
257
  lyrics: str = "",
258
  cot_text: str = "",
259
- ) -> str:
260
- """Shared vllm path: accept prebuilt formatted prompt and return text."""
 
 
 
 
 
261
  from nanovllm import SamplingParams
262
 
 
 
 
 
263
  # Determine effective temperature for sampler
264
- use_phase_temperatures = metadata_temperature is not None or codes_temperature is not None
 
 
265
  effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
266
 
267
- # Use shared constrained processor if enabled
268
- constrained_processor = None
269
- if use_constrained_decoding or use_phase_temperatures:
270
- # Reset processor state for new generation
271
- self.constrained_processor.reset()
272
-
273
- # Use shared processor, just update caption and settings
274
- self.constrained_processor.enabled = use_constrained_decoding
275
- self.constrained_processor.debug = constrained_decoding_debug
276
- self.constrained_processor.metadata_temperature = metadata_temperature if use_phase_temperatures else None
277
- self.constrained_processor.codes_temperature = codes_temperature if use_phase_temperatures else None
278
- self.constrained_processor.set_target_duration(target_duration)
279
- # Always call set_user_metadata to ensure previous settings are cleared if None
280
- self.constrained_processor.set_user_metadata(user_metadata)
281
- self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
282
- # Set skip_caption and skip_language based on flags
283
- self.constrained_processor.set_skip_genres(skip_genres)
284
- self.constrained_processor.set_skip_caption(skip_caption)
285
- self.constrained_processor.set_skip_language(skip_language)
286
- # Set generation phase for phase-aware processing
287
- self.constrained_processor.set_generation_phase(generation_phase)
288
-
289
- constrained_processor = self.constrained_processor
290
 
291
  sampling_params = SamplingParams(
292
  max_tokens=self.max_model_len - 64,
@@ -301,119 +466,25 @@ class LLMHandler:
301
 
302
  if cfg_scale > 1.0:
303
  # Build unconditional prompt based on generation phase
304
- if generation_phase == "codes":
305
- # Codes phase: use empty CoT in unconditional prompt
306
- # formatted_prompt was built with build_formatted_prompt_with_cot(caption, lyrics, cot_text)
307
- # For unconditional, we use empty CoT: build_formatted_prompt_with_cot(caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=...)
308
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
309
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
310
- )
311
- else:
312
- # CoT phase: unconditional prompt
313
- # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
314
- formatted_unconditional_prompt = self.build_formatted_prompt(
315
- caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
316
- )
317
-
318
- outputs = self.llm.generate(
319
- [formatted_prompt],
320
- sampling_params,
321
- unconditional_prompts=[formatted_unconditional_prompt],
322
- )
323
- else:
324
- outputs = self.llm.generate([formatted_prompt], sampling_params)
325
-
326
- # Extract text (retain original selection order/logic)
327
- if isinstance(outputs, list) and len(outputs) > 0:
328
- if hasattr(outputs[0], "outputs") and len(outputs[0].outputs) > 0:
329
- output_text = outputs[0].outputs[0].text
330
- elif hasattr(outputs[0], "text"):
331
- output_text = outputs[0].text
332
- elif isinstance(outputs[0], dict) and "text" in outputs[0]:
333
- output_text = outputs[0]["text"]
334
- else:
335
- output_text = str(outputs[0])
336
- else:
337
- output_text = str(outputs)
338
-
339
- return output_text
340
-
341
- def _run_vllm_batch(
342
- self,
343
- formatted_prompts: List[str],
344
- temperature: float,
345
- cfg_scale: float,
346
- negative_prompt: str,
347
- top_k: Optional[int],
348
- top_p: Optional[float],
349
- repetition_penalty: float,
350
- use_constrained_decoding: bool = True,
351
- constrained_decoding_debug: bool = False,
352
- target_duration: Optional[float] = None,
353
- generation_phase: str = "codes",
354
- caption: str = "",
355
- lyrics: str = "",
356
- cot_text: str = "",
357
- seeds: Optional[List[int]] = None,
358
- ) -> List[str]:
359
- """Batch generation using vllm backend"""
360
- from nanovllm import SamplingParams
361
-
362
- batch_size = len(formatted_prompts)
363
-
364
- # Determine effective temperature for sampler
365
- effective_sampler_temp = temperature
366
-
367
- # Use shared constrained processor if enabled
368
- # Note: vllm batch mode uses same processor for all items
369
- constrained_processor = None
370
- if use_constrained_decoding:
371
- # Reset processor state for new generation
372
- self.constrained_processor.reset()
373
-
374
- self.constrained_processor.enabled = use_constrained_decoding
375
- self.constrained_processor.debug = constrained_decoding_debug
376
- self.constrained_processor.metadata_temperature = None
377
- self.constrained_processor.codes_temperature = None
378
- self.constrained_processor.set_target_duration(target_duration)
379
- self.constrained_processor.set_user_metadata(None)
380
- self.constrained_processor.set_stop_at_reasoning(False)
381
- self.constrained_processor.set_skip_genres(True)
382
- self.constrained_processor.set_skip_caption(True)
383
- self.constrained_processor.set_skip_language(True)
384
- self.constrained_processor.set_generation_phase(generation_phase)
385
-
386
- constrained_processor = self.constrained_processor
387
-
388
- # Build sampling params
389
- sampling_params = SamplingParams(
390
- max_tokens=self.max_model_len - 64,
391
- temperature=effective_sampler_temp,
392
- cfg_scale=cfg_scale,
393
- top_k=top_k,
394
- top_p=top_p,
395
- repetition_penalty=repetition_penalty,
396
- logits_processor=constrained_processor,
397
- logits_processor_update_state=constrained_processor.update_state if constrained_processor else None,
398
- )
399
-
400
- # Generate with or without CFG
401
- if cfg_scale > 1.0:
402
- # Build unconditional prompts
403
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
404
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
405
  )
406
  unconditional_prompts = [formatted_unconditional_prompt] * batch_size
407
 
408
  outputs = self.llm.generate(
409
- formatted_prompts,
410
  sampling_params,
411
  unconditional_prompts=unconditional_prompts,
412
  )
413
  else:
414
- outputs = self.llm.generate(formatted_prompts, sampling_params)
415
-
416
- # Extract text from each output
417
  output_texts = []
418
  for output in outputs:
419
  if hasattr(output, "outputs") and len(output.outputs) > 0:
@@ -424,70 +495,11 @@ class LLMHandler:
424
  output_texts.append(output["text"])
425
  else:
426
  output_texts.append(str(output))
427
-
428
- return output_texts
429
 
430
- def _run_pt_batch(
431
- self,
432
- formatted_prompts: List[str],
433
- temperature: float,
434
- cfg_scale: float,
435
- negative_prompt: str,
436
- top_k: Optional[int],
437
- top_p: Optional[float],
438
- repetition_penalty: float,
439
- use_constrained_decoding: bool = True,
440
- constrained_decoding_debug: bool = False,
441
- target_duration: Optional[float] = None,
442
- generation_phase: str = "codes",
443
- caption: str = "",
444
- lyrics: str = "",
445
- cot_text: str = "",
446
- seeds: Optional[List[int]] = None,
447
- ) -> List[str]:
448
- """Batch generation using PyTorch backend"""
449
- import random
450
-
451
- batch_size = len(formatted_prompts)
452
- output_texts = []
453
-
454
- # Generate each item sequentially with different seeds
455
- # (PyTorch backend doesn't support true batching efficiently)
456
- for i, formatted_prompt in enumerate(formatted_prompts):
457
- # Set seed for this item if provided
458
- if seeds and i < len(seeds):
459
- torch.manual_seed(seeds[i])
460
- if torch.cuda.is_available():
461
- torch.cuda.manual_seed_all(seeds[i])
462
-
463
- # Generate using single-item method
464
- output_text = self._run_pt_from_formatted(
465
- formatted_prompt=formatted_prompt,
466
- temperature=temperature,
467
- cfg_scale=cfg_scale,
468
- negative_prompt=negative_prompt,
469
- top_k=top_k,
470
- top_p=top_p,
471
- repetition_penalty=repetition_penalty,
472
- use_constrained_decoding=use_constrained_decoding,
473
- constrained_decoding_debug=constrained_decoding_debug,
474
- target_duration=target_duration,
475
- user_metadata=None,
476
- stop_at_reasoning=False,
477
- skip_genres=True,
478
- skip_caption=True,
479
- skip_language=True,
480
- generation_phase=generation_phase,
481
- caption=caption,
482
- lyrics=lyrics,
483
- cot_text=cot_text,
484
- )
485
-
486
- output_texts.append(output_text)
487
-
488
- return output_texts
489
 
490
- def _run_pt_from_formatted(
491
  self,
492
  formatted_prompt: str,
493
  temperature: float,
@@ -496,20 +508,20 @@ class LLMHandler:
496
  top_k: Optional[int],
497
  top_p: Optional[float],
498
  repetition_penalty: float,
499
- use_constrained_decoding: bool = True,
500
- constrained_decoding_debug: bool = False,
501
- target_duration: Optional[float] = None,
502
- user_metadata: Optional[Dict[str, Optional[str]]] = None,
503
- stop_at_reasoning: bool = False,
504
- skip_genres: bool = True,
505
- skip_caption: bool = False,
506
- skip_language: bool = False,
507
- generation_phase: str = "cot",
508
- caption: str = "",
509
- lyrics: str = "",
510
- cot_text: str = "",
511
  ) -> str:
512
- """Shared PyTorch path: accept prebuilt formatted prompt and return text."""
513
  inputs = self.llm_tokenizer(
514
  formatted_prompt,
515
  return_tensors="pt",
@@ -517,27 +529,19 @@ class LLMHandler:
517
  truncation=True,
518
  )
519
 
520
- # Use shared constrained processor if enabled
521
- constrained_processor = None
522
- if use_constrained_decoding:
523
- # Reset processor state for new generation
524
- self.constrained_processor.reset()
525
-
526
- # Use shared processor, just update caption and settings
527
- self.constrained_processor.enabled = use_constrained_decoding
528
- self.constrained_processor.debug = constrained_decoding_debug
529
- self.constrained_processor.set_target_duration(target_duration)
530
- # Always call set_user_metadata to ensure previous settings are cleared if None
531
- self.constrained_processor.set_user_metadata(user_metadata)
532
- self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
533
- # Set skip_caption and skip_language based on flags
534
- self.constrained_processor.set_skip_genres(skip_genres)
535
- self.constrained_processor.set_skip_caption(skip_caption)
536
- self.constrained_processor.set_skip_language(skip_language)
537
- # Set generation phase for phase-aware processing
538
- self.constrained_processor.set_generation_phase(generation_phase)
539
-
540
- constrained_processor = self.constrained_processor
541
 
542
  with self._load_model_context():
543
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -546,25 +550,18 @@ class LLMHandler:
546
  max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
547
 
548
  # Build logits processor list (only for CFG and repetition penalty)
549
- logits_processor = LogitsProcessorList()
550
-
551
- # Add repetition penalty if needed
552
- if repetition_penalty != 1.0:
553
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
554
 
555
  if cfg_scale > 1.0:
556
  # Build unconditional prompt based on generation phase
557
- if generation_phase == "codes":
558
- # Codes phase: use empty CoT in unconditional prompt
559
- formatted_unconditional_prompt = self.build_formatted_prompt_with_cot(
560
- caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
561
- )
562
- else:
563
- # CoT phase: unconditional prompt
564
- # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
565
- formatted_unconditional_prompt = self.build_formatted_prompt(
566
- caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
567
- )
568
 
569
  # Tokenize both prompts together to ensure same length (with left padding)
570
  # Left padding is important for generation tasks
@@ -657,7 +654,101 @@ class LLMHandler:
657
 
658
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
659
  return output_text
660
-
661
  def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
662
  """Check if all required metadata are present."""
663
  if user_metadata is None:
@@ -708,7 +799,9 @@ class LLMHandler:
708
  use_cot_caption: bool = True,
709
  use_cot_language: bool = True,
710
  is_format_caption: bool = False,
711
- ) -> Tuple[Dict[str, Any], str, str]:
 
 
712
  """Two-phase LM generation: CoT generation followed by audio codes generation.
713
 
714
  - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
@@ -721,30 +814,56 @@ class LLMHandler:
721
  If specified, constrained decoding will inject these values directly.
722
  use_cot_caption: Whether to generate caption in CoT (default True).
723
  use_cot_language: Whether to generate language in CoT (default True).
724
  """
725
  import time
 
726
 
727
  infer_type = (infer_type or "").strip().lower()
728
  if infer_type not in {"dit", "llm_dit"}:
 
 
729
  return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
730
-
731
  metadata = {}
732
  audio_codes = ""
733
  has_all_metas = self.has_all_metas(user_metadata)
734
-
735
- # Timing variables
736
  phase1_time = 0.0
737
  phase2_time = 0.0
738
 
739
  # ========== PHASE 1: CoT Generation ==========
740
- # Always generate CoT unless all metadata are user-provided
741
- if not has_all_metas or not is_format_caption:
742
- logger.info("Phase 1: Generating CoT metadata...")
743
  phase1_start = time.time()
744
 
745
  # Build formatted prompt for CoT phase
746
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
747
-
748
  logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
749
  # Generate CoT (stop at </think>)
750
  cot_output_text, status = self.generate_from_formatted_prompt(
@@ -774,23 +893,39 @@ class LLMHandler:
774
  phase1_time = time.time() - phase1_start
775
 
776
  if not cot_output_text:
 
 
777
  return {}, "", status
778
 
779
  # Parse metadata from CoT output
780
  metadata, _ = self.parse_lm_output(cot_output_text)
781
- logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
782
  else:
783
  # Use user-provided metadata
784
- logger.info("Phase 1: Using user-provided metadata (skipping generation)")
785
  metadata = {k: v for k, v in user_metadata.items() if v is not None}
786
 
787
  # If infer_type is 'dit', stop here and return only metadata
788
  if infer_type == "dit":
789
- status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
790
- return metadata, "", status_msg
791
 
792
  # ========== PHASE 2: Audio Codes Generation ==========
793
- logger.info("Phase 2: Generating audio codes...")
794
  phase2_start = time.time()
795
 
796
  # Format metadata as CoT using YAML (matching training format)
@@ -799,221 +934,110 @@ class LLMHandler:
799
  # Build formatted prompt with CoT for codes generation phase
800
  formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
801
  logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
802
- # Generate audio codes
803
- codes_output_text, status = self.generate_from_formatted_prompt(
804
- formatted_prompt=formatted_prompt_with_cot,
805
- cfg={
806
- "temperature": temperature,
807
- "cfg_scale": cfg_scale,
808
- "negative_prompt": negative_prompt,
809
- "top_k": top_k,
810
- "top_p": top_p,
811
- "repetition_penalty": repetition_penalty,
812
- "target_duration": target_duration,
813
- "user_metadata": None, # No user metadata injection in Phase 2
814
- "skip_caption": True, # Skip caption since CoT is already included
815
- "skip_language": True, # Skip language since CoT is already included
816
- "generation_phase": "codes",
817
- # Pass context for building unconditional prompt in codes phase
818
- "caption": caption,
819
- "lyrics": lyrics,
820
- "cot_text": cot_text,
821
- },
822
- use_constrained_decoding=use_constrained_decoding,
823
- constrained_decoding_debug=constrained_decoding_debug,
824
- stop_at_reasoning=False, # Generate codes until EOS
825
- )
826
-
827
- if not codes_output_text:
828
- return metadata, "", status
829
-
830
- phase2_time = time.time() - phase2_start
831
-
832
- # Parse audio codes from output (metadata should be same as Phase 1)
833
- _, audio_codes = self.parse_lm_output(codes_output_text)
834
-
835
- codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
836
- logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
837
 
838
- status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
839
- return metadata, audio_codes, status_msg
840
-
841
- def generate_with_stop_condition_batch(
842
- self,
843
- caption: str,
844
- lyrics: str,
845
- batch_size: int,
846
- infer_type: str = "llm_dit",
847
- temperature: float = 0.85,
848
- cfg_scale: float = 1.0,
849
- negative_prompt: str = "NO USER INPUT",
850
- top_k: Optional[int] = None,
851
- top_p: Optional[float] = None,
852
- repetition_penalty: float = 1.0,
853
- use_constrained_decoding: bool = True,
854
- constrained_decoding_debug: bool = False,
855
- target_duration: Optional[float] = None,
856
- user_metadata: Optional[Dict[str, Optional[str]]] = None,
857
- use_cot_caption: bool = True,
858
- use_cot_language: bool = True,
859
- is_format_caption: bool = False,
860
- seeds: Optional[List[int]] = None,
861
- ) -> Tuple[List[Dict[str, Any]], List[str], str]:
862
- """
863
- Batch version of generate_with_stop_condition.
864
-
865
- Generates multiple audio codes with same conditions but different seeds (for diversity).
866
-
867
- Args:
868
- caption: Same caption for all items
869
- lyrics: Same lyrics for all items
870
- batch_size: Number of items to generate
871
- seeds: Optional list of seeds for each batch item (for reproducibility)
872
- ... (other args same as generate_with_stop_condition)
873
-
874
- Returns:
875
- Tuple of (metadata_list, audio_codes_list, status_message)
876
- - metadata_list: List of metadata dicts (same metadata for all items)
877
- - audio_codes_list: List of audio code strings (one per item, different due to sampling)
878
- - status_message: Generation status
879
- """
880
- import random
881
- import time
882
-
883
- infer_type = (infer_type or "").strip().lower()
884
- if infer_type not in {"dit", "llm_dit"}:
885
- return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
886
-
887
- # Generate seeds if not provided
888
- if seeds is None:
889
- seeds = [random.randint(0, 2**32 - 1) for _ in range(batch_size)]
890
- elif len(seeds) < batch_size:
891
- # Pad with random seeds if not enough provided
892
- seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(batch_size - len(seeds))]
893
- else:
894
- seeds = seeds[:batch_size] # Truncate if too many
895
-
896
- # Timing variables
897
- phase1_time = 0.0
898
- phase2_time = 0.0
899
-
900
- # ========== PHASE 1: CoT Generation (ONCE for all items) ==========
901
- has_all_metas = self.has_all_metas(user_metadata)
902
-
903
- if not has_all_metas or not is_format_caption:
904
- logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
905
- phase1_start = time.time()
906
 
907
- # Generate CoT metadata once (same for all batch items)
908
- metadata, _, status = self.generate_with_stop_condition(
909
- caption=caption,
910
- lyrics=lyrics,
911
- infer_type="dit", # Only generate metadata
912
- temperature=temperature,
913
- cfg_scale=cfg_scale,
914
- negative_prompt=negative_prompt,
915
- top_k=top_k,
916
- top_p=top_p,
917
- repetition_penalty=repetition_penalty,
918
  use_constrained_decoding=use_constrained_decoding,
919
  constrained_decoding_debug=constrained_decoding_debug,
920
- target_duration=target_duration,
921
- user_metadata=user_metadata,
922
- use_cot_caption=use_cot_caption,
923
- use_cot_language=use_cot_language,
924
- is_format_caption=is_format_caption,
925
  )
926
 
927
- phase1_time = time.time() - phase1_start
 
928
 
929
- if not metadata:
930
- return [], [], status
931
 
932
- logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
933
- else:
934
- # Use user-provided metadata
935
- logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
936
- metadata = {k: v for k, v in user_metadata.items() if v is not None}
937
-
938
- # If infer_type is 'dit', stop here and return only metadata
939
- if infer_type == "dit":
940
- metadata_list = [metadata.copy() for _ in range(batch_size)]
941
- status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
942
- return metadata_list, [""] * batch_size, status_msg
943
-
944
- # ========== PHASE 2: Audio Codes Generation (BATCH) ==========
945
- logger.info(f"Batch Phase 2: Generating audio codes for {batch_size} items...")
946
- phase2_start = time.time()
947
-
948
- # Format metadata as CoT
949
- cot_text = self._format_metadata_as_cot(metadata)
950
-
951
- # Build formatted prompt with CoT
952
- formatted_prompt = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
953
-
954
- # Replicate prompt for batch (all items have same prompt, differ by seeds)
955
- formatted_prompts = [formatted_prompt] * batch_size
956
-
957
- # Call backend-specific batch generation
958
- try:
959
- if self.llm_backend == "vllm":
960
- codes_outputs = self._run_vllm_batch(
961
- formatted_prompts=formatted_prompts,
962
- temperature=temperature,
963
- cfg_scale=cfg_scale,
964
- negative_prompt=negative_prompt,
965
- top_k=top_k,
966
- top_p=top_p,
967
- repetition_penalty=repetition_penalty,
968
- use_constrained_decoding=use_constrained_decoding,
969
- constrained_decoding_debug=constrained_decoding_debug,
970
- target_duration=target_duration,
971
- generation_phase="codes",
972
- caption=caption,
973
- lyrics=lyrics,
974
- cot_text=cot_text,
975
- seeds=seeds,
976
- )
977
- else: # pt backend
978
- codes_outputs = self._run_pt_batch(
979
- formatted_prompts=formatted_prompts,
980
- temperature=temperature,
981
- cfg_scale=cfg_scale,
982
- negative_prompt=negative_prompt,
983
- top_k=top_k,
984
- top_p=top_p,
985
- repetition_penalty=repetition_penalty,
986
- use_constrained_decoding=use_constrained_decoding,
987
- constrained_decoding_debug=constrained_decoding_debug,
988
- target_duration=target_duration,
989
- generation_phase="codes",
990
- caption=caption,
991
- lyrics=lyrics,
992
- cot_text=cot_text,
993
- seeds=seeds,
994
- )
995
- except Exception as e:
996
- error_msg = f"❌ Error in batch codes generation: {str(e)}"
997
- logger.error(error_msg)
998
- return [], [], error_msg
999
-
1000
- # Parse audio codes from each output
1001
- audio_codes_list = []
1002
- metadata_list = []
1003
- for output_text in codes_outputs:
1004
- _, audio_codes = self.parse_lm_output(output_text)
1005
- audio_codes_list.append(audio_codes)
1006
- metadata_list.append(metadata.copy()) # Same metadata for all
1007
-
1008
- phase2_time = time.time() - phase2_start
1009
-
1010
- # Log results
1011
- codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
1012
- logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
1013
-
1014
- status_msg = f"✅ Batch generation completed ({batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
1015
- return metadata_list, audio_codes_list, status_msg
1016
-
1017
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
1018
  """
1019
  Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
@@ -1035,7 +1059,7 @@ class LLMHandler:
1035
  if is_negative_prompt:
1036
  # Unconditional prompt for CFG
1037
  # Check if user provided a meaningful negative prompt (not the default)
1038
- has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
1039
 
1040
  if generation_phase == "cot":
1041
  # CoT phase unconditional prompt
@@ -1086,7 +1110,7 @@ class LLMHandler:
1086
  if is_negative_prompt:
1087
  # Unconditional prompt for codes phase
1088
  # Check if user provided a meaningful negative prompt
1089
- has_negative_prompt = negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT"
1090
 
1091
  # Use empty CoT for unconditional
1092
  cot_for_prompt = "<think>\n</think>"
@@ -1369,8 +1393,8 @@ class LLMHandler:
1369
 
1370
  try:
1371
  if self.llm_backend == "vllm":
1372
- output_text = self._run_vllm_from_formatted(
1373
- formatted_prompt=formatted_prompt,
1374
  temperature=temperature,
1375
  cfg_scale=cfg_scale,
1376
  negative_prompt=negative_prompt,
@@ -1393,8 +1417,8 @@ class LLMHandler:
1393
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
1394
 
1395
  # PyTorch backend
1396
- output_text = self._run_pt_from_formatted(
1397
- formatted_prompt=formatted_prompt,
1398
  temperature=temperature,
1399
  cfg_scale=cfg_scale,
1400
  negative_prompt=negative_prompt,
@@ -1459,26 +1483,12 @@ class LLMHandler:
1459
  eos_token_id = pad_token_id
1460
 
1461
  # Build logits processor for repetition penalty
1462
- logits_processor = LogitsProcessorList()
1463
- if repetition_penalty != 1.0:
1464
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
1465
 
1466
  with torch.no_grad():
1467
  for step in range(max_new_tokens):
1468
  # Forward pass
1469
- if past_key_values is None:
1470
- outputs = model(
1471
- input_ids=generated_ids,
1472
- **model_kwargs,
1473
- use_cache=use_cache,
1474
- )
1475
- else:
1476
- outputs = model(
1477
- input_ids=generated_ids[:, -1:],
1478
- past_key_values=past_key_values,
1479
- **model_kwargs,
1480
- use_cache=use_cache,
1481
- )
1482
 
1483
  # Get logits for the last position
1484
  next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
@@ -1491,41 +1501,18 @@ class LLMHandler:
1491
  for processor in logits_processor:
1492
  next_token_logits = processor(generated_ids, next_token_logits)
1493
 
1494
- # Apply top-k filtering
1495
- if top_k is not None and top_k > 0:
1496
- indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
1497
- next_token_logits[indices_to_remove] = float('-inf')
1498
-
1499
- # Apply top-p filtering
1500
- if top_p is not None and 0.0 < top_p < 1.0:
1501
- sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
1502
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1503
- sorted_indices_to_remove = cumulative_probs > top_p
1504
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1505
- sorted_indices_to_remove[..., 0] = 0
1506
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1507
- next_token_logits[indices_to_remove] = float('-inf')
1508
 
1509
  # Apply temperature and sample
1510
- if temperature > 0:
1511
- next_token_logits = next_token_logits / temperature
1512
- probs = torch.softmax(next_token_logits, dim=-1)
1513
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1514
- else:
1515
- next_tokens = torch.argmax(next_token_logits, dim=-1)
1516
 
1517
  # Update constrained processor state
1518
- if constrained_processor is not None:
1519
- for b in range(next_tokens.shape[0]):
1520
- constrained_processor.update_state(next_tokens[b].item())
1521
 
1522
  # Check for EOS token
1523
- should_stop = False
1524
- if torch.any(next_tokens == eos_token_id):
1525
- should_stop = True
1526
- elif pad_token_id is not None and pad_token_id != eos_token_id:
1527
- if torch.any(next_tokens == pad_token_id):
1528
- should_stop = True
1529
 
1530
  # Append token to sequence
1531
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
@@ -1601,28 +1588,12 @@ class LLMHandler:
1601
  eos_token_id = pad_token_id
1602
 
1603
  # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
1604
- logits_processor = LogitsProcessorList()
1605
- if repetition_penalty != 1.0:
1606
- logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
1607
 
1608
  with torch.no_grad():
1609
  for step in range(max_new_tokens):
1610
  # Forward pass for the entire batch (conditional + unconditional)
1611
- if past_key_values is None:
1612
- # First step: full forward pass
1613
- outputs = model(
1614
- input_ids=generated_ids,
1615
- **model_kwargs,
1616
- use_cache=use_cache,
1617
- )
1618
- else:
1619
- # Subsequent steps: only forward the last token (utilizing KV cache)
1620
- outputs = model(
1621
- input_ids=generated_ids[:, -1:],
1622
- past_key_values=past_key_values,
1623
- **model_kwargs,
1624
- use_cache=use_cache,
1625
- )
1626
 
1627
  # Get logits for the last position
1628
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
@@ -1645,45 +1616,20 @@ class LLMHandler:
1645
  for processor in logits_processor:
1646
  cfg_logits = processor(current_input_ids, cfg_logits)
1647
 
1648
- # Apply top-k filtering
1649
- if top_k is not None and top_k > 0:
1650
- indices_to_remove = cfg_logits < torch.topk(cfg_logits, top_k)[0][..., -1, None]
1651
- cfg_logits[indices_to_remove] = float('-inf')
1652
-
1653
- # Apply top-p (nucleus) filtering
1654
- if top_p is not None and 0.0 < top_p < 1.0:
1655
- sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
1656
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
1657
- # Remove tokens with cumulative probability above the threshold
1658
- sorted_indices_to_remove = cumulative_probs > top_p
1659
- # Shift the indices to the right to keep also the first token above the threshold
1660
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1661
- sorted_indices_to_remove[..., 0] = 0
1662
- indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1663
- cfg_logits[indices_to_remove] = float('-inf')
1664
 
1665
  # Apply temperature and sample
1666
- if temperature > 0:
1667
- cfg_logits = cfg_logits / temperature
1668
- probs = torch.softmax(cfg_logits, dim=-1)
1669
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1670
- else:
1671
- next_tokens = torch.argmax(cfg_logits, dim=-1)
1672
 
1673
  # Update constrained processor state AFTER sampling
1674
- if constrained_processor is not None:
1675
- for b in range(next_tokens.shape[0]):
1676
- constrained_processor.update_state(next_tokens[b].item())
1677
 
1678
  # Check for EOS token in conditional sequences BEFORE unsqueezing
1679
  # Stop if any conditional sequence generates EOS token
1680
  # next_tokens shape: [batch_size] (only conditional tokens)
1681
- should_stop = False
1682
- if torch.any(next_tokens == eos_token_id):
1683
- should_stop = True
1684
- elif pad_token_id is not None and pad_token_id != eos_token_id:
1685
- if torch.any(next_tokens == pad_token_id):
1686
- should_stop = True
1687
 
1688
  # Apply the same sampled tokens to both conditional and unconditional sequences
1689
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
 
5
  import os
6
  import traceback
7
  import time
8
+ from typing import Optional, Dict, Any, Tuple, List, Union
9
  from contextlib import contextmanager
10
 
11
  import yaml
 
85
  except Exception as e:
86
  return 0.9, False
87
 
88
+ def _has_meaningful_negative_prompt(self, negative_prompt: str) -> bool:
89
+ """Check if negative prompt is meaningful (not default/empty)"""
90
+ return bool(negative_prompt and negative_prompt.strip() and negative_prompt.strip() != "NO USER INPUT")
91
+
92
+ def _build_logits_processor(self, repetition_penalty: float) -> LogitsProcessorList:
93
+ """Build logits processor list with repetition penalty if needed"""
94
+ logits_processor = LogitsProcessorList()
95
+ if repetition_penalty != 1.0:
96
+ logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
97
+ return logits_processor
98
+
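As an illustrative aside, the processor list built by this helper is applied as a single callable over (input_ids, scores); with penalty == 1.0 it stays empty and acts as a pass-through. A minimal self-contained sketch using the same transformers classes (tensor values below are arbitrary placeholders):

import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

# One processor in the list; calling the list applies every processor in order.
processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.2)])
input_ids = torch.tensor([[5, 5, 9]])      # previously generated token ids
scores = torch.randn(1, 16)                # raw logits for the next token
adjusted = processors(input_ids, scores)
print(adjusted.shape)                      # torch.Size([1, 16])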
99
+ def _setup_constrained_processor(
100
+ self,
101
+ use_constrained_decoding: bool,
102
+ constrained_decoding_debug: bool,
103
+ target_duration: Optional[float],
104
+ user_metadata: Optional[Dict[str, Optional[str]]],
105
+ stop_at_reasoning: bool,
106
+ skip_genres: bool,
107
+ skip_caption: bool,
108
+ skip_language: bool,
109
+ generation_phase: str,
110
+ is_batch: bool = False,
111
+ metadata_temperature: Optional[float] = None,
112
+ codes_temperature: Optional[float] = None,
113
+ ) -> Optional[MetadataConstrainedLogitsProcessor]:
114
+ """Setup and configure constrained processor for generation"""
115
+ use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
116
+
117
+ if not use_constrained_decoding and not use_phase_temperatures:
118
+ return None
119
+
120
+ # Reset processor state for new generation
121
+ self.constrained_processor.reset()
122
+
123
+ # Use shared processor, just update settings
124
+ self.constrained_processor.enabled = use_constrained_decoding
125
+ self.constrained_processor.debug = constrained_decoding_debug
126
+
127
+ # Phase temperatures only supported in single mode
128
+ if use_phase_temperatures:
129
+ self.constrained_processor.metadata_temperature = metadata_temperature
130
+ self.constrained_processor.codes_temperature = codes_temperature
131
+ else:
132
+ self.constrained_processor.metadata_temperature = None
133
+ self.constrained_processor.codes_temperature = None
134
+
135
+ self.constrained_processor.set_target_duration(target_duration)
136
+
137
+ # Batch mode uses default/disabled settings for these options
138
+ if is_batch:
139
+ self.constrained_processor.set_user_metadata(None)
140
+ self.constrained_processor.set_stop_at_reasoning(False)
141
+ self.constrained_processor.set_skip_genres(True)
142
+ self.constrained_processor.set_skip_caption(True)
143
+ self.constrained_processor.set_skip_language(True)
144
+ else:
145
+ # Single mode uses provided settings
146
+ self.constrained_processor.set_user_metadata(user_metadata)
147
+ self.constrained_processor.set_stop_at_reasoning(stop_at_reasoning)
148
+ self.constrained_processor.set_skip_genres(skip_genres)
149
+ self.constrained_processor.set_skip_caption(skip_caption)
150
+ self.constrained_processor.set_skip_language(skip_language)
151
+
152
+ # Set generation phase for phase-aware processing
153
+ self.constrained_processor.set_generation_phase(generation_phase)
154
+
155
+ return self.constrained_processor
156
+
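The reset-and-reconfigure pattern above can be illustrated with a minimal stateful logits processor. This is only a sketch: the TinyStatefulProcessor class below is hypothetical and far simpler than MetadataConstrainedLogitsProcessor, but it shows why reset() must run before every new generation and how update_state() is fed each sampled token.

import torch

class TinyStatefulProcessor:
    """Hypothetical minimal stateful logits processor reused across generations."""
    def __init__(self, banned_token_ids=None):
        self.banned_token_ids = set(banned_token_ids or [])
        self.enabled = True
        self.step = 0

    def reset(self):
        # Called once per generation so stale state never leaks into the next request.
        self.step = 0

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        if self.enabled:
            for tok in self.banned_token_ids:
                scores[:, tok] = float("-inf")
        return scores

    def update_state(self, token_id: int):
        # Called after each sampled token so the processor can track position/phase.
        self.step += 1

# Usage: one shared instance, reset and reconfigured before every prompt.
proc = TinyStatefulProcessor(banned_token_ids=[0])
proc.reset()
scores = proc(torch.zeros(1, 5, dtype=torch.long), torch.randn(1, 10))
proc.update_state(3)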
157
+ def _build_unconditional_prompt(
158
+ self,
159
+ caption: str,
160
+ lyrics: str,
161
+ cot_text: str,
162
+ negative_prompt: str,
163
+ generation_phase: str,
164
+ is_batch: bool = False,
165
+ ) -> str:
166
+ """Build unconditional prompt for CFG based on generation phase and batch mode"""
167
+ if is_batch or generation_phase == "codes":
168
+ # Codes phase or batch mode: use empty CoT in unconditional prompt
169
+ return self.build_formatted_prompt_with_cot(
170
+ caption, lyrics, cot_text, is_negative_prompt=True, negative_prompt=negative_prompt
171
+ )
172
+ else:
173
+ # CoT phase (single mode only): unconditional prompt
174
+ # If negative_prompt is provided, use it as caption; otherwise remove caption and keep only lyrics
175
+ return self.build_formatted_prompt(
176
+ caption, lyrics, is_negative_prompt=True, generation_phase="cot", negative_prompt=negative_prompt
177
+ )
178
+
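The unconditional prompt built here exists so classifier-free guidance can contrast conditional and unconditional predictions. A hedged sketch of the usual CFG mixing rule follows; the exact combination applied by the sampler in this repo may differ.

import torch

def cfg_mix(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Classifier-free guidance: push the distribution away from the unconditional
    # prediction and toward the conditional one. cfg_scale == 1.0 is a no-op.
    return uncond_logits + cfg_scale * (cond_logits - uncond_logits)

cond = torch.randn(1, 32)
uncond = torch.randn(1, 32)
guided = cfg_mix(cond, uncond, cfg_scale=2.0)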
179
+ def _load_pytorch_model(self, model_path: str, device: str) -> Tuple[bool, str]:
180
+ """Load PyTorch model from path and return (success, status_message)"""
181
+ try:
182
+ self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
183
+ if not self.offload_to_cpu:
184
+ self.llm = self.llm.to(device).to(self.dtype)
185
+ else:
186
+ self.llm = self.llm.to("cpu").to(self.dtype)
187
+ self.llm.eval()
188
+ self.llm_backend = "pt"
189
+ self.llm_initialized = True
190
+ logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
191
+ status_msg = f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nBackend: PyTorch\nDevice: {device}"
192
+ return True, status_msg
193
+ except Exception as e:
194
+ return False, f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
195
+
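A minimal sketch of the load-then-place pattern used by this helper, with a toy nn.Module standing in for the real checkpoint (names and values are placeholders):

import torch
import torch.nn as nn

def place_model(model: nn.Module, device: str, dtype: torch.dtype, offload_to_cpu: bool) -> nn.Module:
    """Keep weights on CPU when offloading, otherwise move to the target device,
    and always switch to eval mode before inference."""
    target = "cpu" if offload_to_cpu else device
    model = model.to(target).to(dtype)
    model.eval()
    return model

m = place_model(nn.Linear(4, 4), device="cpu", dtype=torch.float32, offload_to_cpu=False)
print(next(m.parameters()).device)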
196
+ def _apply_top_k_filter(self, logits: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
197
+ """Apply top-k filtering to logits"""
198
+ if top_k is not None and top_k > 0:
199
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
200
+ logits[indices_to_remove] = float('-inf')
201
+ return logits
202
+
203
+ def _apply_top_p_filter(self, logits: torch.Tensor, top_p: Optional[float]) -> torch.Tensor:
204
+ """Apply top-p (nucleus) filtering to logits"""
205
+ if top_p is not None and 0.0 < top_p < 1.0:
206
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
207
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
208
+ sorted_indices_to_remove = cumulative_probs > top_p
209
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
210
+ sorted_indices_to_remove[..., 0] = 0
211
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
212
+ logits[indices_to_remove] = float('-inf')
213
+ return logits
214
+
215
+ def _sample_tokens(self, logits: torch.Tensor, temperature: float) -> torch.Tensor:
216
+ """Sample tokens from logits with temperature"""
217
+ if temperature > 0:
218
+ logits = logits / temperature
219
+ probs = torch.softmax(logits, dim=-1)
220
+ return torch.multinomial(probs, num_samples=1).squeeze(1)
221
+ else:
222
+ return torch.argmax(logits, dim=-1)
223
+
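Taken together, the three helpers above implement the standard top-k → top-p → temperature sampling pipeline. A self-contained version over plain tensors, with arbitrary example shapes:

import torch

def sample_next_token(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9,
                      temperature: float = 0.85) -> torch.Tensor:
    # Top-k: drop everything below the k-th largest logit.
    if top_k > 0:
        kth = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: drop the tail once cumulative probability exceeds the threshold.
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), float("-inf"))
    # Temperature sampling, or greedy when temperature == 0.
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(logits, dim=-1)

print(sample_next_token(torch.randn(2, 100)))  # tensor of shape [2]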
224
+ def _check_eos_token(self, tokens: torch.Tensor, eos_token_id: int, pad_token_id: Optional[int]) -> bool:
225
+ """Check if any token in the batch is EOS or pad token"""
226
+ if torch.any(tokens == eos_token_id):
227
+ return True
228
+ if pad_token_id is not None and pad_token_id != eos_token_id:
229
+ if torch.any(tokens == pad_token_id):
230
+ return True
231
+ return False
232
+
233
+ def _update_constrained_processor_state(self, constrained_processor: Optional[MetadataConstrainedLogitsProcessor], tokens: torch.Tensor):
234
+ """Update constrained processor state with generated tokens"""
235
+ if constrained_processor is not None:
236
+ for b in range(tokens.shape[0]):
237
+ constrained_processor.update_state(tokens[b].item())
238
+
239
+ def _forward_pass(
240
+ self,
241
+ model: Any,
242
+ generated_ids: torch.Tensor,
243
+ model_kwargs: Dict[str, Any],
244
+ past_key_values: Optional[Any],
245
+ use_cache: bool,
246
+ ) -> Any:
247
+ """Perform forward pass with KV cache support"""
248
+ if past_key_values is None:
249
+ outputs = model(
250
+ input_ids=generated_ids,
251
+ **model_kwargs,
252
+ use_cache=use_cache,
253
+ )
254
+ else:
255
+ outputs = model(
256
+ input_ids=generated_ids[:, -1:],
257
+ past_key_values=past_key_values,
258
+ **model_kwargs,
259
+ use_cache=use_cache,
260
+ )
261
+ return outputs
262
+
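The cache-aware branching in _forward_pass follows the usual incremental-decoding pattern: one full forward over the prompt, then single-token forwards that reuse past_key_values. A schematic greedy loop, assuming a Hugging Face-style causal LM that returns logits and past_key_values:

import torch

def greedy_decode_with_cache(model, input_ids: torch.Tensor, max_new_tokens: int = 16) -> torch.Tensor:
    """Schematic KV-cache loop; `model` is assumed to be a HF-style causal LM."""
    past_key_values = None
    generated = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if past_key_values is None:
                out = model(input_ids=generated, use_cache=True)           # full prompt
            else:
                out = model(input_ids=generated[:, -1:],                   # last token only
                            past_key_values=past_key_values, use_cache=True)
            past_key_values = out.past_key_values
            next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
    return generated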
263
+ def _normalize_batch_input(self, formatted_prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
264
+ """Normalize batch input: convert single string to list and return (list, is_batch)"""
265
+ is_batch = isinstance(formatted_prompts, list)
266
+ if is_batch:
267
+ return formatted_prompts, is_batch
268
+ else:
269
+ return [formatted_prompts], is_batch
270
+
271
  def initialize(
272
  self,
273
  checkpoint_dir: str,
 
333
  # vllm initialization failed, fallback to PyTorch
334
  if not self.llm_initialized:
335
  logger.warning("vllm initialization failed, falling back to PyTorch backend")
336
+ success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
337
+ if not success:
338
+ return status_msg, False
339
+ status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
340
  # If vllm initialization succeeded, self.llm_initialized should already be True
341
  else:
342
  # Use PyTorch backend (pt)
343
+ success, status_msg = self._load_pytorch_model(full_lm_model_path, device)
344
+ if not success:
345
+ return status_msg, False
346
 
347
  return status_msg, True
348
 
349
  except Exception as e:
350
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
 
351
 
352
  def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
353
  """Initialize 5Hz LM model using vllm backend"""
 
393
  return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
394
  except Exception as e:
395
  self.llm_initialized = False
396
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
397
 
398
+ def _run_vllm(
399
  self,
400
+ formatted_prompts: Union[str, List[str]],
401
  temperature: float,
402
  cfg_scale: float,
403
  negative_prompt: str,
 
406
  repetition_penalty: float,
407
  use_constrained_decoding: bool = True,
408
  constrained_decoding_debug: bool = False,
409
+ metadata_temperature: Optional[float] = None,
410
  codes_temperature: Optional[float] = None,
411
  target_duration: Optional[float] = None,
412
  user_metadata: Optional[Dict[str, Optional[str]]] = None,
 
418
  caption: str = "",
419
  lyrics: str = "",
420
  cot_text: str = "",
421
+ seeds: Optional[List[int]] = None,
422
+ ) -> Union[str, List[str]]:
423
+ """
424
+ Unified vllm generation function supporting both single and batch modes.
425
+ Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
426
+ Returns a single string for single mode, or a list of strings for batch mode.
427
+ """
428
  from nanovllm import SamplingParams
429
 
430
+ # Determine if batch mode
431
+ formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
432
+ batch_size = len(formatted_prompt_list)
433
+
434
  # Determine effective temperature for sampler
435
+ # Batch mode doesn't support phase temperatures, so use simple temperature
436
+ # Single mode supports phase temperatures
437
+ use_phase_temperatures = not is_batch and (metadata_temperature is not None or codes_temperature is not None)
438
  effective_sampler_temp = 1.0 if use_phase_temperatures else temperature
439
 
440
+ # Setup constrained processor
441
+ constrained_processor = self._setup_constrained_processor(
442
+ use_constrained_decoding=use_constrained_decoding or use_phase_temperatures,
443
+ constrained_decoding_debug=constrained_decoding_debug,
444
+ target_duration=target_duration,
445
+ user_metadata=user_metadata,
446
+ stop_at_reasoning=stop_at_reasoning,
447
+ skip_genres=skip_genres,
448
+ skip_caption=skip_caption,
449
+ skip_language=skip_language,
450
+ generation_phase=generation_phase,
451
+ is_batch=is_batch,
452
+ metadata_temperature=metadata_temperature,
453
+ codes_temperature=codes_temperature,
454
+ )
455
 
456
  sampling_params = SamplingParams(
457
  max_tokens=self.max_model_len - 64,
 
466
 
467
  if cfg_scale > 1.0:
468
  # Build unconditional prompt based on generation phase
469
+ formatted_unconditional_prompt = self._build_unconditional_prompt(
470
+ caption=caption,
471
+ lyrics=lyrics,
472
+ cot_text=cot_text,
473
+ negative_prompt=negative_prompt,
474
+ generation_phase=generation_phase,
475
+ is_batch=is_batch,
476
  )
477
  unconditional_prompts = [formatted_unconditional_prompt] * batch_size
478
 
479
  outputs = self.llm.generate(
480
+ formatted_prompt_list,
481
  sampling_params,
482
  unconditional_prompts=unconditional_prompts,
483
  )
484
  else:
485
+ outputs = self.llm.generate(formatted_prompt_list, sampling_params)
486
+
487
+ # Extract text from outputs
488
  output_texts = []
489
  for output in outputs:
490
  if hasattr(output, "outputs") and len(output.outputs) > 0:
 
495
  output_texts.append(output["text"])
496
  else:
497
  output_texts.append(str(output))
 
 
498
 
499
+ # Return single string for single mode, list for batch mode
500
+ return output_texts[0] if not is_batch else output_texts
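The single/batch contract of this function (str in → str out, list in → list out) reduces to a small normalize/unwrap pattern; the helper names below are illustrative and not part of the handler:

from typing import List, Tuple, Union

def normalize_batch_input(prompts: Union[str, List[str]]) -> Tuple[List[str], bool]:
    # Same contract as _normalize_batch_input: always work on a list internally,
    # remembering whether the caller passed a single prompt.
    if isinstance(prompts, list):
        return prompts, True
    return [prompts], False

def denormalize_output(texts: List[str], is_batch: bool) -> Union[str, List[str]]:
    # Mirror of the return statement above: unwrap single-item results.
    return texts if is_batch else texts[0]

texts, is_batch = normalize_batch_input("one prompt")
assert denormalize_output(["generated"], is_batch) == "generated"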
501
 
502
+ def _run_pt_single(
503
  self,
504
  formatted_prompt: str,
505
  temperature: float,
 
508
  top_k: Optional[int],
509
  top_p: Optional[float],
510
  repetition_penalty: float,
511
+ use_constrained_decoding: bool,
512
+ constrained_decoding_debug: bool,
513
+ target_duration: Optional[float],
514
+ user_metadata: Optional[Dict[str, Optional[str]]],
515
+ stop_at_reasoning: bool,
516
+ skip_genres: bool,
517
+ skip_caption: bool,
518
+ skip_language: bool,
519
+ generation_phase: str,
520
+ caption: str,
521
+ lyrics: str,
522
+ cot_text: str,
523
  ) -> str:
524
+ """Internal helper function for single-item PyTorch generation."""
525
  inputs = self.llm_tokenizer(
526
  formatted_prompt,
527
  return_tensors="pt",
 
529
  truncation=True,
530
  )
531
 
532
+ # Setup constrained processor
533
+ constrained_processor = self._setup_constrained_processor(
534
+ use_constrained_decoding=use_constrained_decoding,
535
+ constrained_decoding_debug=constrained_decoding_debug,
536
+ target_duration=target_duration,
537
+ user_metadata=user_metadata,
538
+ stop_at_reasoning=stop_at_reasoning,
539
+ skip_genres=skip_genres,
540
+ skip_caption=skip_caption,
541
+ skip_language=skip_language,
542
+ generation_phase=generation_phase,
543
+ is_batch=False,
544
+ )
545
 
546
  with self._load_model_context():
547
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
550
  max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
551
 
552
  # Build logits processor list (only for CFG and repetition penalty)
553
+ logits_processor = self._build_logits_processor(repetition_penalty)
554
 
555
  if cfg_scale > 1.0:
556
  # Build unconditional prompt based on generation phase
557
+ formatted_unconditional_prompt = self._build_unconditional_prompt(
558
+ caption=caption,
559
+ lyrics=lyrics,
560
+ cot_text=cot_text,
561
+ negative_prompt=negative_prompt,
562
+ generation_phase=generation_phase,
563
+ is_batch=False,
564
+ )
565
 
566
  # Tokenize both prompts together to ensure same length (with left padding)
567
  # Left padding is important for generation tasks
 
654
 
655
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
656
  return output_text
657
+
658
+ def _run_pt(
659
+ self,
660
+ formatted_prompts: Union[str, List[str]],
661
+ temperature: float,
662
+ cfg_scale: float,
663
+ negative_prompt: str,
664
+ top_k: Optional[int],
665
+ top_p: Optional[float],
666
+ repetition_penalty: float,
667
+ use_constrained_decoding: bool = True,
668
+ constrained_decoding_debug: bool = False,
669
+ target_duration: Optional[float] = None,
670
+ user_metadata: Optional[Dict[str, Optional[str]]] = None,
671
+ stop_at_reasoning: bool = False,
672
+ skip_genres: bool = True,
673
+ skip_caption: bool = False,
674
+ skip_language: bool = False,
675
+ generation_phase: str = "cot",
676
+ caption: str = "",
677
+ lyrics: str = "",
678
+ cot_text: str = "",
679
+ seeds: Optional[List[int]] = None,
680
+ ) -> Union[str, List[str]]:
681
+ """
682
+ Unified PyTorch generation function supporting both single and batch modes.
683
+ Accepts either a single formatted prompt (str) or a list of formatted prompts (List[str]).
684
+ Returns a single string for single mode, or a list of strings for batch mode.
685
+ Note: the PyTorch backend processes batch items sequentially, since it does not support efficient true batched generation.
686
+ """
687
+ # Determine if batch mode
688
+ formatted_prompt_list, is_batch = self._normalize_batch_input(formatted_prompts)
689
+
690
+ # For batch mode, process each item sequentially with different seeds
691
+ if is_batch:
692
+ output_texts = []
693
+ for i, formatted_prompt in enumerate(formatted_prompt_list):
694
+ # Set seed for this item if provided
695
+ if seeds and i < len(seeds):
696
+ torch.manual_seed(seeds[i])
697
+ if torch.cuda.is_available():
698
+ torch.cuda.manual_seed_all(seeds[i])
699
+
700
+ # Generate using single-item method with batch-mode defaults
701
+ output_text = self._run_pt_single(
702
+ formatted_prompt=formatted_prompt,
703
+ temperature=temperature,
704
+ cfg_scale=cfg_scale,
705
+ negative_prompt=negative_prompt,
706
+ top_k=top_k,
707
+ top_p=top_p,
708
+ repetition_penalty=repetition_penalty,
709
+ use_constrained_decoding=use_constrained_decoding,
710
+ constrained_decoding_debug=constrained_decoding_debug,
711
+ target_duration=target_duration,
712
+ user_metadata=None,
713
+ stop_at_reasoning=False,
714
+ skip_genres=True,
715
+ skip_caption=True,
716
+ skip_language=True,
717
+ generation_phase=generation_phase,
718
+ caption=caption,
719
+ lyrics=lyrics,
720
+ cot_text=cot_text,
721
+ )
722
+
723
+ output_texts.append(output_text)
724
+
725
+ return output_texts
726
+
727
+ # Single mode: process the formatted prompt
728
+ formatted_prompt = formatted_prompt_list[0]
729
+
730
+ return self._run_pt_single(
731
+ formatted_prompt=formatted_prompt,
732
+ temperature=temperature,
733
+ cfg_scale=cfg_scale,
734
+ negative_prompt=negative_prompt,
735
+ top_k=top_k,
736
+ top_p=top_p,
737
+ repetition_penalty=repetition_penalty,
738
+ use_constrained_decoding=use_constrained_decoding,
739
+ constrained_decoding_debug=constrained_decoding_debug,
740
+ target_duration=target_duration,
741
+ user_metadata=user_metadata,
742
+ stop_at_reasoning=stop_at_reasoning,
743
+ skip_genres=skip_genres,
744
+ skip_caption=skip_caption,
745
+ skip_language=skip_language,
746
+ generation_phase=generation_phase,
747
+ caption=caption,
748
+ lyrics=lyrics,
749
+ cot_text=cot_text,
750
+ )
751
+
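In batch mode the PyTorch path reseeds the RNG before each item so runs are reproducible yet diverse. A stripped-down sketch of that per-item seeding (the generation call itself is replaced by a placeholder):

import random
import torch

def seeded_sequential_generation(prompts, seeds=None):
    """Each item gets its own seed; identical seeds reproduce identical samples."""
    if seeds is None:
        seeds = [random.randint(0, 2**32 - 1) for _ in prompts]
    outputs = []
    for prompt, seed in zip(prompts, seeds):
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        # Placeholder for the actual single-item generation call.
        outputs.append((prompt, torch.rand(1).item()))
    return outputs

print(seeded_sequential_generation(["a", "b"], seeds=[1, 2]))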
752
  def has_all_metas(self, user_metadata: Optional[Dict[str, Optional[str]]]) -> bool:
753
  """Check if all required metadata are present."""
754
  if user_metadata is None:
 
799
  use_cot_caption: bool = True,
800
  use_cot_language: bool = True,
801
  is_format_caption: bool = False,
802
+ batch_size: Optional[int] = None,
803
+ seeds: Optional[List[int]] = None,
804
+ ) -> Union[Tuple[Dict[str, Any], str, str], Tuple[List[Dict[str, Any]], List[str], str]]:
805
  """Two-phase LM generation: CoT generation followed by audio codes generation.
806
 
807
  - infer_type='dit': Phase 1 only - generate CoT and return metas (no audio codes)
 
814
  If specified, constrained decoding will inject these values directly.
815
  use_cot_caption: Whether to generate caption in CoT (default True).
816
  use_cot_language: Whether to generate language in CoT (default True).
817
+ batch_size: Optional batch size for batch generation. If None or 1, returns single result.
818
+ If > 1, returns batch results (lists).
819
+ seeds: Optional list of seeds for batch generation (for reproducibility).
820
+ Only used when batch_size > 1.
821
+
822
+ Returns:
823
+ If batch_size is None or 1: (metadata, audio_codes, status_msg)
824
+ If batch_size > 1: (metadata_list, audio_codes_list, status_msg)
825
  """
826
  import time
827
+ import random
828
 
829
  infer_type = (infer_type or "").strip().lower()
830
  if infer_type not in {"dit", "llm_dit"}:
831
+ if batch_size and batch_size > 1:
832
+ return [], [], f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
833
  return {}, "", f"❌ invalid infer_type: {infer_type!r} (expected 'dit' or 'llm_dit')"
834
+
835
+ # Determine if batch mode
836
+ is_batch = batch_size and batch_size > 1
837
+ actual_batch_size = batch_size if is_batch else 1
838
+
839
+ # Initialize variables
840
  metadata = {}
841
  audio_codes = ""
842
  has_all_metas = self.has_all_metas(user_metadata)
 
 
843
  phase1_time = 0.0
844
  phase2_time = 0.0
845
 
846
+ # Handle seeds for batch mode
847
+ if is_batch:
848
+ if seeds is None:
849
+ seeds = [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size)]
850
+ elif len(seeds) < actual_batch_size:
851
+ seeds = list(seeds) + [random.randint(0, 2**32 - 1) for _ in range(actual_batch_size - len(seeds))]
852
+ else:
853
+ seeds = seeds[:actual_batch_size]
854
+
855
  # ========== PHASE 1: CoT Generation ==========
856
+ # Skip CoT if all metadata are user-provided OR caption is already formatted
857
+ if not has_all_metas and not is_format_caption:
858
+ if is_batch:
859
+ logger.info("Batch Phase 1: Generating CoT metadata (once for all items)...")
860
+ else:
861
+ logger.info("Phase 1: Generating CoT metadata...")
862
  phase1_start = time.time()
863
 
864
  # Build formatted prompt for CoT phase
865
  formatted_prompt = self.build_formatted_prompt(caption, lyrics, generation_phase="cot")
866
+
867
  logger.info(f"generate_with_stop_condition: formatted_prompt={formatted_prompt}")
868
  # Generate CoT (stop at </think>)
869
  cot_output_text, status = self.generate_from_formatted_prompt(
 
893
  phase1_time = time.time() - phase1_start
894
 
895
  if not cot_output_text:
896
+ if is_batch:
897
+ return [], [], status
898
  return {}, "", status
899
 
900
  # Parse metadata from CoT output
901
  metadata, _ = self.parse_lm_output(cot_output_text)
902
+ if is_batch:
903
+ logger.info(f"Batch Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
904
+ else:
905
+ logger.info(f"Phase 1 completed in {phase1_time:.2f}s. Generated metadata: {list(metadata.keys())}")
906
  else:
907
  # Use user-provided metadata
908
+ if is_batch:
909
+ logger.info("Batch Phase 1: Using user-provided metadata (skipping generation)")
910
+ else:
911
+ logger.info("Phase 1: Using user-provided metadata (skipping generation)")
912
  metadata = {k: v for k, v in user_metadata.items() if v is not None}
913
 
914
  # If infer_type is 'dit', stop here and return only metadata
915
  if infer_type == "dit":
916
+ if is_batch:
917
+ metadata_list = [metadata.copy() for _ in range(actual_batch_size)]
918
+ status_msg = f"✅ Generated CoT metadata successfully (batch mode)\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
919
+ return metadata_list, [""] * actual_batch_size, status_msg
920
+ else:
921
+ status_msg = f"✅ Generated CoT metadata successfully\nFields: {', '.join(metadata.keys())}\nPhase1: {phase1_time:.2f}s"
922
+ return metadata, "", status_msg
923
 
924
  # ========== PHASE 2: Audio Codes Generation ==========
925
+ if is_batch:
926
+ logger.info(f"Batch Phase 2: Generating audio codes for {actual_batch_size} items...")
927
+ else:
928
+ logger.info("Phase 2: Generating audio codes...")
929
  phase2_start = time.time()
930
 
931
  # Format metadata as CoT using YAML (matching training format)
 
934
  # Build formatted prompt with CoT for codes generation phase
935
  formatted_prompt_with_cot = self.build_formatted_prompt_with_cot(caption, lyrics, cot_text)
936
  logger.info(f"generate_with_stop_condition: formatted_prompt_with_cot={formatted_prompt_with_cot}")
937
 
938
+ if is_batch:
939
+ # Batch mode: generate codes for all items
940
+ formatted_prompts = [formatted_prompt_with_cot] * actual_batch_size
941
 
942
+ # Call backend-specific batch generation
943
+ try:
944
+ if self.llm_backend == "vllm":
945
+ codes_outputs = self._run_vllm(
946
+ formatted_prompts=formatted_prompts,
947
+ temperature=temperature,
948
+ cfg_scale=cfg_scale,
949
+ negative_prompt=negative_prompt,
950
+ top_k=top_k,
951
+ top_p=top_p,
952
+ repetition_penalty=repetition_penalty,
953
+ use_constrained_decoding=use_constrained_decoding,
954
+ constrained_decoding_debug=constrained_decoding_debug,
955
+ target_duration=target_duration,
956
+ generation_phase="codes",
957
+ caption=caption,
958
+ lyrics=lyrics,
959
+ cot_text=cot_text,
960
+ seeds=seeds,
961
+ )
962
+ else: # pt backend
963
+ codes_outputs = self._run_pt(
964
+ formatted_prompts=formatted_prompts,
965
+ temperature=temperature,
966
+ cfg_scale=cfg_scale,
967
+ negative_prompt=negative_prompt,
968
+ top_k=top_k,
969
+ top_p=top_p,
970
+ repetition_penalty=repetition_penalty,
971
+ use_constrained_decoding=use_constrained_decoding,
972
+ constrained_decoding_debug=constrained_decoding_debug,
973
+ target_duration=target_duration,
974
+ generation_phase="codes",
975
+ caption=caption,
976
+ lyrics=lyrics,
977
+ cot_text=cot_text,
978
+ seeds=seeds,
979
+ )
980
+ except Exception as e:
981
+ error_msg = f"❌ Error in batch codes generation: {str(e)}"
982
+ logger.error(error_msg)
983
+ return [], [], error_msg
984
+
985
+ # Parse audio codes from each output
986
+ audio_codes_list = []
987
+ metadata_list = []
988
+ for output_text in codes_outputs:
989
+ _, audio_codes_item = self.parse_lm_output(output_text)
990
+ audio_codes_list.append(audio_codes_item)
991
+ metadata_list.append(metadata.copy()) # Same metadata for all
992
+
993
+ phase2_time = time.time() - phase2_start
994
+
995
+ # Log results
996
+ codes_counts = [len(codes.split('<|audio_code_')) - 1 if codes else 0 for codes in audio_codes_list]
997
+ logger.info(f"Batch Phase 2 completed in {phase2_time:.2f}s. Generated codes: {codes_counts}")
998
+
999
+ status_msg = f"✅ Batch generation completed ({actual_batch_size} items)\nPhase 1: CoT metadata\nPhase 2: {sum(codes_counts)} total codes ({codes_counts})\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
1000
+ return metadata_list, audio_codes_list, status_msg
1001
+ else:
1002
+ # Single mode: generate codes for one item
1003
+ codes_output_text, status = self.generate_from_formatted_prompt(
1004
+ formatted_prompt=formatted_prompt_with_cot,
1005
+ cfg={
1006
+ "temperature": temperature,
1007
+ "cfg_scale": cfg_scale,
1008
+ "negative_prompt": negative_prompt,
1009
+ "top_k": top_k,
1010
+ "top_p": top_p,
1011
+ "repetition_penalty": repetition_penalty,
1012
+ "target_duration": target_duration,
1013
+ "user_metadata": None, # No user metadata injection in Phase 2
1014
+ "skip_caption": True, # Skip caption since CoT is already included
1015
+ "skip_language": True, # Skip language since CoT is already included
1016
+ "generation_phase": "codes",
1017
+ # Pass context for building unconditional prompt in codes phase
1018
+ "caption": caption,
1019
+ "lyrics": lyrics,
1020
+ "cot_text": cot_text,
1021
+ },
1022
  use_constrained_decoding=use_constrained_decoding,
1023
  constrained_decoding_debug=constrained_decoding_debug,
1024
+ stop_at_reasoning=False, # Generate codes until EOS
1025
  )
1026
 
1027
+ if not codes_output_text:
1028
+ return metadata, "", status
1029
 
1030
+ phase2_time = time.time() - phase2_start
 
1031
 
1032
+ # Parse audio codes from output (metadata should be same as Phase 1)
1033
+ _, audio_codes = self.parse_lm_output(codes_output_text)
1034
+
1035
+ codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
1036
+ logger.info(f"Phase 2 completed in {phase2_time:.2f}s. Generated {codes_count} audio codes")
1037
+
1038
+ status_msg = f"✅ Generated successfully (2-phase)\nPhase 1: CoT metadata\nPhase 2: {codes_count} audio codes\nPhase1: {phase1_time:.2f}s, Phase2: {phase2_time:.2f}s"
1039
+ return metadata, audio_codes, status_msg
1040
+
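Because the return shape now depends on batch_size, callers that want uniform handling can normalize both shapes to the batch form. An illustrative helper, not part of LLMHandler:

from typing import Any, Dict, List, Tuple, Union

SingleResult = Tuple[Dict[str, Any], str, str]
BatchResult = Tuple[List[Dict[str, Any]], List[str], str]

def as_batch(result: Union[SingleResult, BatchResult]) -> BatchResult:
    """Normalize the unified generate_with_stop_condition return value to batch shape."""
    metadata, codes, status = result
    if isinstance(metadata, dict):
        return [metadata], [codes], status
    return metadata, codes, status

# A single-mode result becomes a one-item batch; batch results pass through untouched.
print(as_batch(({"bpm": "120"}, "<|audio_code_1|>", "ok")))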
1041
  def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False, generation_phase: str = "cot", negative_prompt: str = "NO USER INPUT") -> str:
1042
  """
1043
  Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
 
1059
  if is_negative_prompt:
1060
  # Unconditional prompt for CFG
1061
  # Check if user provided a meaningful negative prompt (not the default)
1062
+ has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
1063
 
1064
  if generation_phase == "cot":
1065
  # CoT phase unconditional prompt
 
1110
  if is_negative_prompt:
1111
  # Unconditional prompt for codes phase
1112
  # Check if user provided a meaningful negative prompt
1113
+ has_negative_prompt = self._has_meaningful_negative_prompt(negative_prompt)
1114
 
1115
  # Use empty CoT for unconditional
1116
  cot_for_prompt = "<think>\n</think>"
 
1393
 
1394
  try:
1395
  if self.llm_backend == "vllm":
1396
+ output_text = self._run_vllm(
1397
+ formatted_prompts=formatted_prompt,
1398
  temperature=temperature,
1399
  cfg_scale=cfg_scale,
1400
  negative_prompt=negative_prompt,
 
1417
  return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
1418
 
1419
  # PyTorch backend
1420
+ output_text = self._run_pt(
1421
+ formatted_prompts=formatted_prompt,
1422
  temperature=temperature,
1423
  cfg_scale=cfg_scale,
1424
  negative_prompt=negative_prompt,
 
1483
  eos_token_id = pad_token_id
1484
 
1485
  # Build logits processor for repetition penalty
1486
+ logits_processor = self._build_logits_processor(repetition_penalty)
 
 
1487
 
1488
  with torch.no_grad():
1489
  for step in range(max_new_tokens):
1490
  # Forward pass
1491
+ outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
1492
 
1493
  # Get logits for the last position
1494
  next_token_logits = outputs.logits[:, -1, :] # [batch_size, vocab_size]
 
1501
  for processor in logits_processor:
1502
  next_token_logits = processor(generated_ids, next_token_logits)
1503
 
1504
+ # Apply top-k and top-p filtering
1505
+ next_token_logits = self._apply_top_k_filter(next_token_logits, top_k)
1506
+ next_token_logits = self._apply_top_p_filter(next_token_logits, top_p)
1507
 
1508
  # Apply temperature and sample
1509
+ next_tokens = self._sample_tokens(next_token_logits, temperature)
1510
 
1511
  # Update constrained processor state
1512
+ self._update_constrained_processor_state(constrained_processor, next_tokens)
 
 
1513
 
1514
  # Check for EOS token
1515
+ should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
1516
 
1517
  # Append token to sequence
1518
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
 
1588
  eos_token_id = pad_token_id
1589
 
1590
  # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
1591
+ logits_processor = self._build_logits_processor(repetition_penalty)
 
 
1592
 
1593
  with torch.no_grad():
1594
  for step in range(max_new_tokens):
1595
  # Forward pass for the entire batch (conditional + unconditional)
1596
+ outputs = self._forward_pass(model, generated_ids, model_kwargs, past_key_values, use_cache)
1597
 
1598
  # Get logits for the last position
1599
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
 
1616
  for processor in logits_processor:
1617
  cfg_logits = processor(current_input_ids, cfg_logits)
1618
 
1619
+ # Apply top-k and top-p filtering
1620
+ cfg_logits = self._apply_top_k_filter(cfg_logits, top_k)
1621
+ cfg_logits = self._apply_top_p_filter(cfg_logits, top_p)
1622
 
1623
  # Apply temperature and sample
1624
+ next_tokens = self._sample_tokens(cfg_logits, temperature)
1625
 
1626
  # Update constrained processor state AFTER sampling
1627
+ self._update_constrained_processor_state(constrained_processor, next_tokens)
 
 
1628
 
1629
  # Check for EOS token in conditional sequences BEFORE unsqueezing
1630
  # Stop if any conditional sequence generates EOS token
1631
  # next_tokens shape: [batch_size] (only conditional tokens)
1632
+ should_stop = self._check_eos_token(next_tokens, eos_token_id, pad_token_id)
1633
 
1634
  # Apply the same sampled tokens to both conditional and unconditional sequences
1635
  next_tokens_unsqueezed = next_tokens.unsqueeze(1)
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py CHANGED
@@ -68,10 +68,16 @@ class ModelRunner:
68
  self.model = Qwen3ForCausalLM(hf_config)
69
  load_model(self.model, config.model)
70
  self.sampler = Sampler()
71
  self.warmup_model()
72
  self.allocate_kv_cache()
73
  if not self.enforce_eager:
74
  self.capture_cudagraph()
 
75
  torch.set_default_device("cpu")
76
  torch.set_default_dtype(default_dtype)
77
 
@@ -84,6 +90,24 @@ class ModelRunner:
84
  self.shm = SharedMemory(name="nanovllm")
85
  self.loop()
86
 
87
  def exit(self):
88
  if self.world_size > 1:
89
  self.shm.close()
@@ -216,57 +240,49 @@ class ModelRunner:
216
  return input_ids, positions
217
 
218
  def prepare_decode(self, seqs: list[Sequence]):
219
- input_ids = []
220
- positions = []
221
- slot_mapping = []
222
- context_lens = []
223
- for seq in seqs:
224
- input_ids.append(seq.last_token)
225
- positions.append(len(seq) - 1)
226
- context_lens.append(len(seq))
227
- slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
228
- input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
229
- positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
230
- slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
231
- context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
 
 
232
  block_tables = self.prepare_block_tables(seqs)
233
  set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
234
  return input_ids, positions
235
 
236
  def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
237
- """Prepare sampling parameters. For CFG batch, only return parameters for conditional sequences."""
238
  if is_cfg_batch:
239
- # For CFG batch, seqs contains [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
240
- # We only need parameters for conditional sequences (first half)
241
- num_cond = len(seqs) // 2
242
- temperatures = []
243
- cfg_scales = []
244
- top_ks = []
245
- top_ps = []
246
- repetition_penalties = []
247
- for seq in seqs[:num_cond]:
248
- temperatures.append(seq.temperature)
249
- cfg_scales.append(seq.cfg_scale)
250
- top_ks.append(seq.top_k if seq.top_k is not None else 0)
251
- top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
252
- repetition_penalties.append(seq.repetition_penalty)
253
  else:
254
- temperatures = []
255
- cfg_scales = []
256
- top_ks = []
257
- top_ps = []
258
- repetition_penalties = []
259
- for seq in seqs:
260
- temperatures.append(seq.temperature)
261
- cfg_scales.append(seq.cfg_scale)
262
- top_ks.append(seq.top_k if seq.top_k is not None else 0)
263
- top_ps.append(seq.top_p if seq.top_p is not None else 1.0)
264
- repetition_penalties.append(seq.repetition_penalty)
265
- temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
266
- cfg_scales = torch.tensor(cfg_scales, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
267
- top_ks = torch.tensor(top_ks, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
268
- top_ps = torch.tensor(top_ps, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
269
- repetition_penalties = torch.tensor(repetition_penalties, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
 
 
270
  return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
271
 
272
  @torch.inference_mode()
 
68
  self.model = Qwen3ForCausalLM(hf_config)
69
  load_model(self.model, config.model)
70
  self.sampler = Sampler()
71
+
72
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
73
+ # Must be called before warmup_model() since it uses these buffers
74
+ self._allocate_sample_buffers()
75
+
76
  self.warmup_model()
77
  self.allocate_kv_cache()
78
  if not self.enforce_eager:
79
  self.capture_cudagraph()
80
+
81
  torch.set_default_device("cpu")
82
  torch.set_default_dtype(default_dtype)
83
 
 
90
  self.shm = SharedMemory(name="nanovllm")
91
  self.loop()
92
 
93
+ def _allocate_sample_buffers(self):
94
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
95
+ max_bs = self.config.max_num_seqs
96
+
97
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
98
+ # Must explicitly specify device="cpu" since default device may be "cuda"
99
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
100
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
101
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
102
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
103
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
104
+
105
+ # Pre-allocate decode buffers on CPU with pinned memory
106
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
107
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
108
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
109
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
110
+
111
  def exit(self):
112
  if self.world_size > 1:
113
  self.shm.close()
 
240
  return input_ids, positions
241
 
242
  def prepare_decode(self, seqs: list[Sequence]):
243
+ """Optimized decode preparation using pre-allocated buffers."""
244
+ bs = len(seqs)
245
+
246
+ # Use pre-allocated CPU buffers
247
+ for i, seq in enumerate(seqs):
248
+ self._cpu_input_ids[i] = seq.last_token
249
+ self._cpu_positions[i] = len(seq) - 1
250
+ self._cpu_context_lens[i] = len(seq)
251
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
252
+
253
+ # Transfer to GPU using sliced views
254
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
255
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
256
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
257
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
258
  block_tables = self.prepare_block_tables(seqs)
259
  set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
260
  return input_ids, positions
261
 
262
  def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
263
+ """Optimized sample preparation using pre-allocated buffers."""
264
  if is_cfg_batch:
265
+ num_seqs = len(seqs) // 2
266
+ target_seqs = seqs[:num_seqs]
267
  else:
268
+ num_seqs = len(seqs)
269
+ target_seqs = seqs
270
+
271
+ # Fill pre-allocated CPU buffers
272
+ for i, seq in enumerate(target_seqs):
273
+ self._cpu_temperatures[i] = seq.temperature
274
+ self._cpu_cfg_scales[i] = seq.cfg_scale
275
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
276
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
277
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
278
+
279
+ # Transfer to GPU using sliced views (single batched transfer)
280
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
281
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
282
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True)
283
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True)
284
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True)
285
+
286
  return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
287
 
288
  @torch.inference_mode()
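
For readers skimming the diff above: the optimized prepare_decode/prepare_sample path allocates page-locked (pinned) CPU tensors once, fills only the leading slice in place each step, and copies that slice to the GPU with non_blocking=True. A minimal standalone sketch of the same idea follows; the buffer size and helper name are illustrative assumptions, not part of this commit.

import torch

MAX_BS = 256  # assumed upper bound on batch size, analogous to config.max_num_seqs

# Allocate once: pinned (page-locked) host memory enables asynchronous H2D copies.
cpu_buf = torch.zeros(MAX_BS, dtype=torch.int64, device="cpu", pin_memory=True)

def to_gpu(values):
    """Fill the reusable pinned buffer in place, then copy only the used slice."""
    bs = len(values)
    for i, v in enumerate(values):
        cpu_buf[i] = v
    # non_blocking=True lets the copy overlap with other CUDA work
    return cpu_buf[:bs].cuda(non_blocking=True)

if torch.cuda.is_available():
    print(to_gpu([7, 42, 3]))  # tensor([ 7, 42,  3], device='cuda:0')
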
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py CHANGED
@@ -3,12 +3,88 @@ from torch import nn
3
  from typing import Optional
4
 
5
 
6
  class Sampler(nn.Module):
7
 
8
  def __init__(self):
9
  super().__init__()
10
 
11
- @torch.compile
12
  def forward(
13
  self,
14
  logits: torch.Tensor,
@@ -19,56 +95,34 @@ class Sampler(nn.Module):
19
  input_ids: Optional[torch.Tensor] = None,
20
  ):
21
  """
22
- Sample tokens from logits with optional top-k, top-p, and repetition penalty.
23
 
24
- Args:
25
- logits: [batch_size, vocab_size] logits tensor
26
- temperatures: [batch_size] temperature values
27
- top_ks: Optional [batch_size] top-k values (None or 0 means no top-k filtering)
28
- top_ps: Optional [batch_size] top-p values (None or 1.0 means no top-p filtering)
29
- repetition_penalties: Optional [batch_size] repetition penalty values (1.0 means no penalty)
30
- input_ids: Optional [batch_size, seq_len] input token ids for repetition penalty
31
  """
32
- batch_size, vocab_size = logits.shape
33
-
34
- # Note: Repetition penalty is applied in ModelRunner before calling sampler
35
- # This allows us to use the full sequence context
36
-
37
  # Apply temperature
38
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
39
 
40
- # Apply top-k filtering if specified
41
- if top_ks is not None:
42
- for i in range(batch_size):
43
- top_k = top_ks[i].item()
44
- if top_k > 0 and top_k < vocab_size:
45
- # Get top-k logits, set others to -inf
46
- top_k_logits, top_k_indices = torch.topk(logits[i], int(top_k), dim=-1)
47
- filtered_logits = torch.full_like(logits[i], float('-inf'))
48
- filtered_logits[top_k_indices] = top_k_logits
49
- logits[i] = filtered_logits
50
 
51
- # Apply top-p (nucleus) filtering if specified
52
- if top_ps is not None:
53
- probs = torch.softmax(logits, dim=-1)
54
- for i in range(batch_size):
55
- top_p = top_ps[i].item()
56
- if 0.0 < top_p < 1.0:
57
- # Sort probabilities in descending order
58
- sorted_probs, sorted_indices = torch.sort(probs[i], descending=True)
59
- # Calculate cumulative probabilities
60
- cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
61
- # Find the cutoff point
62
- cutoff_idx = (cumsum_probs <= top_p).sum().item()
63
- if cutoff_idx < len(sorted_indices):
64
- cutoff_idx += 1 # Include one more token to ensure we have at least one
65
- # Create mask for tokens to keep
66
- mask = torch.zeros_like(probs[i])
67
- mask[sorted_indices[:cutoff_idx]] = 1.0
68
- # Apply mask: set filtered tokens to -inf
69
- logits[i] = torch.where(mask > 0, logits[i], torch.tensor(float('-inf'), device=logits.device))
70
 
71
- # Sample using Gumbel-max trick (equivalent to sampling from softmax)
72
- probs = torch.softmax(logits, dim=-1)
73
- sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
74
- return sample_tokens
 
3
  from typing import Optional
4
 
5
 
6
+ def apply_top_k_top_p(
7
+ logits: torch.Tensor,
8
+ k: Optional[torch.Tensor],
9
+ p: Optional[torch.Tensor],
10
+ ) -> torch.Tensor:
11
+ """Apply top-k and top-p masks to the logits (vLLM style).
12
+
13
+ The logits tensor is updated in-place.
14
+ """
15
+ if p is None:
16
+ if k is None:
17
+ return logits
18
+ # Avoid sorting vocab for top-k only case
19
+ return apply_top_k_only(logits, k)
20
+
21
+ # Need to sort for top-p
22
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
23
+
24
+ if k is not None:
25
+ # Apply top-k first
26
+ vocab_size = logits_sort.size(1)
27
+ # Clamp k to valid range
28
+ k_clamped = k.clamp(1, vocab_size).long()
29
+ top_k_mask_idx = vocab_size - k_clamped # shape: [B]
30
+ # Get the threshold value for each batch
31
+ top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
32
+ top_k_mask = logits_sort < top_k_thresh
33
+ logits_sort.masked_fill_(top_k_mask, float('-inf'))
34
+
35
+ # Apply top-p
36
+ probs_sort = logits_sort.softmax(dim=-1)
37
+ probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) # reuse buffer
38
+ top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
39
+ # Ensure at least one token is kept
40
+ top_p_mask[:, -1] = False
41
+ logits_sort.masked_fill_(top_p_mask, float('-inf'))
42
+
43
+ # Re-sort back to original positions
44
+ logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
45
+ return logits
46
+
47
+
48
+ def apply_top_k_only(
49
+ logits: torch.Tensor,
50
+ k: torch.Tensor,
51
+ ) -> torch.Tensor:
52
+ """Apply top-k mask without sorting the entire vocab (vLLM style).
53
+
54
+ This is much faster than sorting for top-k only cases.
55
+ The logits tensor is updated in-place.
56
+ """
57
+ vocab_size = logits.shape[1]
58
+ # Handle cases where k >= vocab_size (no filtering needed)
59
+ no_top_k_mask = (k <= 0) | (k >= vocab_size)
60
+ # Set invalid k to 1 so we can still gather
61
+ k_safe = k.masked_fill(no_top_k_mask, 1).long()
62
+ # NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
63
+ max_top_k = int(k_safe.max().clamp(max=vocab_size))
64
+
65
+ # Get top-k values for all batches
66
+ # topk.values has shape [batch_size, max_top_k]
67
+ topk_values = logits.topk(max_top_k, dim=1).values
68
+
69
+ # Convert k to 0-based index: we want the k-th largest value (index k-1)
70
+ # Clamp to valid range for gather
71
+ k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1) # shape: [B, 1]
72
+ # Gather the threshold value (the k-th largest)
73
+ top_k_thresh = topk_values.gather(1, k_index)
74
+
75
+ # For rows with no top-k filtering, set threshold to -inf so nothing gets masked
76
+ top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
77
+
78
+ # Mask all values below the threshold
79
+ logits.masked_fill_(logits < top_k_thresh, float('-inf'))
80
+ return logits
81
+
82
+
83
  class Sampler(nn.Module):
84
 
85
  def __init__(self):
86
  super().__init__()
87
 
 
88
  def forward(
89
  self,
90
  logits: torch.Tensor,
 
95
  input_ids: Optional[torch.Tensor] = None,
96
  ):
97
  """
98
+ Sample tokens from logits with optional top-k and top-p filtering.
99
 
100
+ Condition checking is done OUTSIDE the compiled function to avoid
101
+ graph breaks from .any() calls.
102
  """
103
  # Apply temperature
104
  logits = logits.float().div_(temperatures.unsqueeze(dim=1))
105
 
106
+ # Check conditions OUTSIDE compiled code to avoid graph breaks
107
+ # These .any() calls cause CPU-GPU sync, but we do it once here
108
+ # instead of inside the compiled function
109
+ need_topk = top_ks is not None and bool((top_ks > 0).any()) and bool((top_ks < logits.shape[1]).any())
110
+ need_topp = top_ps is not None and bool((top_ps < 1.0).any()) and bool((top_ps > 0.0).any())
111
 
112
+ if need_topk or need_topp:
113
+ # Apply filtering (this part is not compiled due to dynamic control flow)
114
+ logits = apply_top_k_top_p(
115
+ logits,
116
+ top_ks if need_topk else None,
117
+ top_ps if need_topp else None,
118
+ )
119
 
120
+ # Sample using compiled function
121
+ return self._sample(logits)
122
+
123
+ @torch.compile(dynamic=True)
124
+ def _sample(self, logits: torch.Tensor) -> torch.Tensor:
125
+ """Compiled sampling kernel - no graph breaks here."""
126
+ probs = logits.softmax(dim=-1, dtype=torch.float32)
127
+ q = torch.empty_like(probs).exponential_()
128
+ return probs.div(q).argmax(dim=-1)
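
The _sample kernel above relies on the exponential-noise (Gumbel-max style) trick: dividing the softmax probabilities by i.i.d. Exponential(1) noise and taking the argmax selects token i with probability probs[i]. A small self-contained sanity check of that equivalence (a sketch only; the sample count is arbitrary):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])

counts = torch.zeros(3)
for _ in range(20_000):
    # Same trick as Sampler._sample: argmax of probs / Exp(1) noise
    q = torch.empty_like(probs).exponential_()
    counts[probs.div(q).argmax()] += 1

print(counts / counts.sum())  # should be close to [0.1, 0.2, 0.7]
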
acestep/third_parts/nano-vllm/pyproject.toml CHANGED
@@ -15,8 +15,6 @@ dependencies = [
15
  "triton-windows>=3.0.0; sys_platform == 'win32'",
16
  "triton>=3.0.0; sys_platform != 'win32'",
17
  "transformers>=4.51.0",
18
- "flash-attn @ https://github.com/sdbds/flash-attention-for-windows/releases/download/2.8.3/flash_attn-2.8.3+cu128torch2.8.0cxx11abiFALSEfullbackward-cp311-cp311-win_amd64.whl; sys_platform == 'win32'",
19
- "flash-attn; sys_platform != 'win32'",
20
  "xxhash",
21
  ]
22
 
 
15
  "triton-windows>=3.0.0; sys_platform == 'win32'",
16
  "triton>=3.0.0; sys_platform != 'win32'",
17
  "transformers>=4.51.0",
 
  "xxhash",
19
  ]
20
 
profile_inference.py ADDED
@@ -0,0 +1,223 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Profiling script for acestep/inference.py using cProfile
4
+
5
+ Usage:
6
+ python profile_inference.py
7
+ python profile_inference.py --warmup
8
+ """
9
+
10
+ import cProfile
11
+ import pstats
12
+ import io
13
+ import time
14
+ import argparse
15
+ import sys
16
+ import os
17
+
18
+ # Add project root to path
19
+ project_root = os.path.abspath(os.path.dirname(__file__))
20
+ if project_root not in sys.path:
21
+ sys.path.insert(0, project_root)
22
+
23
+ from acestep.inference import generate_music, GenerationParams, GenerationConfig
24
+ from acestep.handler import AceStepHandler
25
+ from acestep.llm_inference import LLMHandler
26
+ import json
27
+ from typing import Tuple
28
+
29
+
30
+ def profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=False):
31
+ """Profile using Python's built-in cProfile.
32
+
33
+ Args:
34
+ warmup: If True, run once for warmup before profiling (default: False)
35
+ """
36
+ print("=" * 80)
37
+ print("Profiling with cProfile")
38
+ print("=" * 80)
39
+
40
+ # Warmup run (to exclude PyTorch compilation overhead)
41
+ if warmup:
42
+ print("\n[Warmup] Running first generation to warm up (PyTorch compilation, etc.)...")
43
+ warmup_start = time.time()
44
+ params.use_cot_metas = False
45
+ config.is_format_caption = True
46
+ config.use_constrained_decoding = False
47
+ warmup_result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
48
+ warmup_time = time.time() - warmup_start
49
+ print(f"[Warmup] Completed in {warmup_time:.2f}s")
50
+ if not warmup_result.success:
51
+ print(f"[Warmup] ⚠ Warmup generation failed: {warmup_result.error}")
52
+ return warmup_result
53
+
54
+ # Actual profiling run (first inference)
55
+ print("\n[Profiling] Running first generation for profiling...")
56
+ profiler = cProfile.Profile()
57
+ profiler.enable()
58
+
59
+ profiling_start = time.time()
60
+ try:
61
+ result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")
62
+ finally:
63
+ profiler.disable()
64
+ profiling_time = time.time() - profiling_start
65
+
66
+ # Create stats
67
+ s = io.StringIO()
68
+ ps = pstats.Stats(profiler, stream=s)
69
+ ps.sort_stats('cumulative')
70
+
71
+ print(f"\n[Profiling] Completed in {profiling_time:.2f}s")
72
+ print("\nTop 30 functions by cumulative time:")
73
+ print("-" * 80)
74
+ ps.print_stats(30)
75
+
76
+ print("\nTop 30 functions by total time:")
77
+ print("-" * 80)
78
+ ps.sort_stats('tottime')
79
+ ps.print_stats(30)
80
+
81
+ # Save detailed report to file
82
+ output_file = "profile_cprofile.txt"
83
+ with open(output_file, 'w') as f:
84
+ # Create a new Stats object with file as stream
85
+ ps_file = pstats.Stats(profiler, stream=f)
86
+ ps_file.sort_stats('cumulative')
87
+ ps_file.print_stats()
88
+ print(f"\nDetailed profile saved to: {output_file}")
89
+
90
+ return result
91
+
92
+
93
+ def main():
94
+ parser = argparse.ArgumentParser(description="Profile acestep/inference.py")
95
+ parser.add_argument(
96
+ "--checkpoint-dir",
97
+ type=str,
98
+ default="./checkpoints",
99
+ help="Path to checkpoints directory"
100
+ )
101
+ parser.add_argument(
102
+ "--config-path",
103
+ type=str,
104
+ default="acestep-v15-turbo-rl",
105
+ help="Model config path"
106
+ )
107
+ parser.add_argument(
108
+ "--device",
109
+ type=str,
110
+ default="cuda",
111
+ help="Device to use (cuda/cpu)"
112
+ )
113
+ parser.add_argument(
114
+ "--lm-model",
115
+ type=str,
116
+ default="acestep-5Hz-lm-0.6B-v3",
117
+ help="LM model path"
118
+ )
119
+ parser.add_argument(
120
+ "--lm-backend",
121
+ type=str,
122
+ default="vllm",
123
+ help="LM backend"
124
+ )
125
+ parser.add_argument(
126
+ "--warmup",
127
+ action="store_true",
128
+ help="Enable warmup run before profiling (default: False, profile first run)"
129
+ )
130
+
131
+ args = parser.parse_args()
132
+
133
+ # Initialize handlers
134
+ print("Initializing handlers...")
135
+ dit_handler = AceStepHandler()
136
+ llm_handler = LLMHandler()
137
+
138
+ # Initialize DiT
139
+ print(" - Initializing DiT model...")
140
+ status_dit, success_dit = dit_handler.initialize_service(
141
+ project_root=project_root,
142
+ config_path=args.config_path,
143
+ device=args.device,
144
+ )
145
+ if not success_dit:
146
+ print(f" ❌ DiT initialization failed: {status_dit}")
147
+ sys.exit(1)
148
+ print(" ✓ DiT model initialized")
149
+
150
+ # Initialize LLM
151
+ print(" - Initializing LLM model...")
152
+ status_llm, success_llm = llm_handler.initialize(
153
+ checkpoint_dir=args.checkpoint_dir,
154
+ lm_model_path=args.lm_model,
155
+ backend=args.lm_backend,
156
+ device=args.device,
157
+ )
158
+ if success_llm:
159
+ print(" ✓ LM model initialized")
160
+ else:
161
+ print(f" ⚠ LM initialization failed: {status_llm}")
162
+
163
+ # Load test parameters from example file (same as acestep/inference.py)
164
+ def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
165
+ """Load configuration from an example JSON file."""
166
+ try:
167
+ with open(example_file, 'r', encoding='utf-8') as f:
168
+ data = json.load(f)
169
+
170
+ # Convert example format to GenerationParams and GenerationConfig
171
+ # Handle time signature format (example uses "4" instead of "4/4")
172
+ time_sig = data.get('timesignature', '')
173
+
174
+ params = GenerationParams(
175
+ caption=data.get('caption', ''),
176
+ lyrics=data.get('lyrics', ''),
177
+ bpm=data.get('bpm'),
178
+ keyscale=data.get('keyscale', ''),
179
+ timesignature=time_sig,
180
+ vocal_language=data.get('language', 'unknown'),
181
+ duration=data.get('duration'),
182
+ thinking=data.get('think', False),
183
+ inference_steps=data.get('inference_steps', 8),
184
+ seed=42,
185
+ )
186
+
187
+ config = GenerationConfig()
188
+ config.batch_size = data.get('batch_size', 1)
189
+
190
+ return params, config
191
+
192
+ except Exception as e:
193
+ print(f" ⚠ Failed to load example file: {e}")
194
+ return None, None
195
+
196
+ # Load production example (same as acestep/inference.py)
197
+ example_file = os.path.join(project_root, "examples", "text2music", "example_05.json")
198
+
199
+ if not os.path.exists(example_file):
200
+ print(f"\n ❌ Example file not found: {example_file}")
201
+ print(" Please ensure the examples directory exists.")
202
+ sys.exit(1)
203
+
204
+ print(f"\n Loading example: {os.path.basename(example_file)}")
205
+ params, config = load_example_config(example_file)
206
+
207
+ if not params or not config:
208
+ print(" ❌ Failed to load example configuration")
209
+ sys.exit(1)
210
+
211
+ print("\n" + "=" * 80)
212
+ print("Starting profiling...")
213
+ print("=" * 80)
214
+
215
+ result = profile_with_cprofile(dit_handler, llm_handler, params, config, warmup=args.warmup)
216
+
217
+ if result and not result.success:
218
+ print(f"\n⚠ Generation failed: {result.error}")
219
+
220
+
221
+ if __name__ == "__main__":
222
+ main()
223
+
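
If the text report written by profile_inference.py is not enough, the same cProfile data can also be dumped in binary form and re-queried with pstats or explored in external viewers such as snakeviz. A short standalone sketch, under the assumption that a Profile instance wraps the call being measured (the squaring loop below is only a stand-in for the real generate_music call):

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
sum(i * i for i in range(10_000))  # stand-in for the generate_music(...) call
profiler.disable()

# Binary dump; loadable by pstats and external viewers
profiler.dump_stats("profile_cprofile.prof")

# Reload later and query programmatically
stats = pstats.Stats("profile_cprofile.prof")
stats.sort_stats("cumulative").print_stats(10)
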