ChuxiJ committed on
Commit
bf5e1fd
·
1 Parent(s): 8ff7c0c

refactor handler

Browse files
acestep/acestep_v15_pipeline.py CHANGED
@@ -10,6 +10,8 @@ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL
10
  os.environ.pop(proxy_var, None)
11
 
12
  from .handler import AceStepHandler
 
 
13
  from .gradio_ui import create_gradio_interface
14
 
15
 
@@ -20,11 +22,13 @@ def create_demo():
20
  Returns:
21
  Gradio Blocks instance
22
  """
23
- # Create handler instance (business logic processor)
24
- handler = AceStepHandler()
 
 
25
 
26
- # Create Gradio interface
27
- demo = create_gradio_interface(handler)
28
 
29
  return demo
30
 
 
10
  os.environ.pop(proxy_var, None)
11
 
12
  from .handler import AceStepHandler
13
+ from .llm_inference import LLMHandler
14
+ from .dataset_handler import DatasetHandler
15
  from .gradio_ui import create_gradio_interface
16
 
17
 
 
22
  Returns:
23
  Gradio Blocks instance
24
  """
25
+ # Create independent handler instances
26
+ dit_handler = AceStepHandler() # DiT handler
27
+ llm_handler = LLMHandler() # LM handler
28
+ dataset_handler = DatasetHandler() # Dataset handler
29
 
30
+ # Create Gradio interface with all handlers
31
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler)
32
 
33
  return demo
34
 
acestep/dataset_handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Dataset Handler
Handles dataset import and exploration functionality
"""
from typing import Optional, Tuple, Any, Dict


class DatasetHandler:
    """Dataset Handler for Dataset Explorer functionality.

    Both operations are currently stubs: the Text2MusicDataset dependency
    is unavailable, so importing is disabled and item lookup returns
    placeholder values matching the UI's expected output shape.
    """

    def __init__(self):
        """Initialize dataset handler with no dataset loaded."""
        # Holds the loaded dataset object once import is re-enabled.
        self.dataset = None
        # Tracks whether a dataset was successfully imported.
        self.dataset_imported = False

    def import_dataset(self, dataset_type: str) -> str:
        """
        Import dataset (temporarily disabled).

        Args:
            dataset_type: Type of dataset to import (e.g., "train", "test")

        Returns:
            Status message string
        """
        self.dataset_imported = False
        # Plain string: there are no interpolated values, so no f-prefix (was F541).
        return "⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."

    def get_item_data(self, *args, **kwargs) -> Tuple:
        """
        Get dataset item (temporarily disabled).

        Returns:
            Tuple of 17 placeholder values matching the expected return
            format consumed by the Gradio UI (index 8 is the status
            message, last element is the task type).
        """
        return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
+
acestep/gradio_ui.py CHANGED
@@ -2,16 +2,19 @@
2
  Gradio UI Components Module
3
  Contains all Gradio interface component definitions and layouts
4
  """
 
5
  import gradio as gr
6
  from typing import Callable, Optional
7
 
8
 
9
- def create_gradio_interface(handler) -> gr.Blocks:
10
  """
11
  Create Gradio interface
12
 
13
  Args:
14
- handler: Business logic handler instance
 
 
15
 
16
  Returns:
17
  Gradio Blocks instance
@@ -42,21 +45,21 @@ def create_gradio_interface(handler) -> gr.Blocks:
42
  """)
43
 
44
  # Dataset Explorer Section
45
- dataset_section = create_dataset_section(handler)
46
 
47
  # Generation Section
48
- generation_section = create_generation_section(handler)
49
 
50
  # Results Section
51
- results_section = create_results_section(handler)
52
 
53
  # Connect event handlers
54
- setup_event_handlers(demo, handler, dataset_section, generation_section, results_section)
55
 
56
  return demo
57
 
58
 
59
- def create_dataset_section(handler) -> dict:
60
  """Create dataset explorer section"""
61
  with gr.Group():
62
  gr.HTML('<div class="section-header"><h3>📊 Dataset Explorer</h3></div>')
@@ -153,7 +156,7 @@ def create_dataset_section(handler) -> dict:
153
  }
154
 
155
 
156
- def create_generation_section(handler) -> dict:
157
  """Create generation section"""
158
  with gr.Group():
159
  gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
@@ -165,7 +168,7 @@ def create_generation_section(handler) -> dict:
165
  with gr.Column(scale=4):
166
  checkpoint_dropdown = gr.Dropdown(
167
  label="Checkpoint File",
168
- choices=handler.get_available_checkpoints(),
169
  value=None,
170
  info="Select a trained model checkpoint file (full path or filename)"
171
  )
@@ -174,7 +177,7 @@ def create_generation_section(handler) -> dict:
174
 
175
  with gr.Row():
176
  # Get available acestep-v15- model list
177
- available_models = handler.get_available_acestep_v15_models()
178
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
179
 
180
  config_path = gr.Dropdown(
@@ -192,7 +195,7 @@ def create_generation_section(handler) -> dict:
192
 
193
  with gr.Row():
194
  # Get available 5Hz LM model list
195
- available_lm_models = handler.get_available_5hz_lm_models()
196
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
197
 
198
  lm_model_path = gr.Dropdown(
@@ -216,7 +219,7 @@ def create_generation_section(handler) -> dict:
216
  info="Check to initialize 5Hz LM during service initialization",
217
  )
218
  # Auto-detect flash attention availability
219
- flash_attn_available = handler.is_flash_attention_available()
220
  use_flash_attention_checkbox = gr.Checkbox(
221
  label="Use Flash Attention",
222
  value=flash_attn_available,
@@ -565,7 +568,7 @@ def create_generation_section(handler) -> dict:
565
  }
566
 
567
 
568
- def create_results_section(handler) -> dict:
569
  """Create results display section"""
570
  with gr.Group():
571
  gr.HTML('<div class="section-header"><h3>🎧 Generated Results</h3></div>')
@@ -620,7 +623,7 @@ def create_results_section(handler) -> dict:
620
  }
621
 
622
 
623
- def setup_event_handlers(demo, handler, dataset_section, generation_section, results_section):
624
  """Setup event handlers connecting UI components and business logic"""
625
 
626
  def update_init_status(status_msg, enable_btn):
@@ -629,14 +632,14 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
629
 
630
  # Dataset handlers
631
  dataset_section["import_dataset_btn"].click(
632
- fn=handler.import_dataset,
633
  inputs=[dataset_section["dataset_type"]],
634
  outputs=[dataset_section["data_status"]]
635
  )
636
 
637
  # Service initialization - refresh checkpoints
638
  def refresh_checkpoints():
639
- choices = handler.get_available_checkpoints()
640
  return gr.update(choices=choices)
641
 
642
  generation_section["refresh_btn"].click(
@@ -698,12 +701,36 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
698
  # Service initialization
699
  def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
700
  """Wrapper for service initialization, returns status and button state"""
701
- status, enable = handler.initialize_service(
702
- checkpoint, config_path, device, init_llm, lm_model_path,
703
- backend=backend,
704
  use_flash_attention=use_flash_attention, compile_model=False,
705
  offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
706
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  return status, gr.update(interactive=enable)
708
 
709
  generation_section["init_btn"].click(
@@ -756,7 +783,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
756
  use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
757
  progress=gr.Progress(track_tqdm=True)
758
  ):
759
- return handler.generate_music(
760
  captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
761
  time_signature=time_signature, vocal_language=vocal_language,
762
  inference_steps=inference_steps, guidance_scale=guidance_scale,
@@ -820,7 +847,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
820
  # 5Hz LM generation (simplified version, can be extended as needed)
821
  def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, negative_prompt):
822
  """Wrapper for 5Hz LM generation"""
823
- metadata, audio_codes, status = handler.generate_with_5hz_lm(caption, lyrics, temperature, cfg_scale, negative_prompt)
824
 
825
  # Extract metadata values and map to UI fields
826
  # Handle bpm
@@ -878,7 +905,7 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
878
  audio_codes_content: str = ""
879
  ) -> tuple:
880
  """Update instruction and UI visibility based on task type."""
881
- instruction = handler.generate_instruction(
882
  task_type=task_type_value,
883
  track_name=track_name_value,
884
  complete_track_classes=complete_track_classes_value
 
2
  Gradio UI Components Module
3
  Contains all Gradio interface component definitions and layouts
4
  """
5
+ import os
6
  import gradio as gr
7
  from typing import Callable, Optional
8
 
9
 
10
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler) -> gr.Blocks:
11
  """
12
  Create Gradio interface
13
 
14
  Args:
15
+ dit_handler: DiT handler instance
16
+ llm_handler: LM handler instance
17
+ dataset_handler: Dataset handler instance
18
 
19
  Returns:
20
  Gradio Blocks instance
 
45
  """)
46
 
47
  # Dataset Explorer Section
48
+ dataset_section = create_dataset_section(dataset_handler)
49
 
50
  # Generation Section
51
+ generation_section = create_generation_section(dit_handler, llm_handler)
52
 
53
  # Results Section
54
+ results_section = create_results_section(dit_handler)
55
 
56
  # Connect event handlers
57
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
58
 
59
  return demo
60
 
61
 
62
+ def create_dataset_section(dataset_handler) -> dict:
63
  """Create dataset explorer section"""
64
  with gr.Group():
65
  gr.HTML('<div class="section-header"><h3>📊 Dataset Explorer</h3></div>')
 
156
  }
157
 
158
 
159
+ def create_generation_section(dit_handler, llm_handler) -> dict:
160
  """Create generation section"""
161
  with gr.Group():
162
  gr.HTML('<div class="section-header"><h3>🎼 ACE-Step V1.5 Demo </h3></div>')
 
168
  with gr.Column(scale=4):
169
  checkpoint_dropdown = gr.Dropdown(
170
  label="Checkpoint File",
171
+ choices=dit_handler.get_available_checkpoints(),
172
  value=None,
173
  info="Select a trained model checkpoint file (full path or filename)"
174
  )
 
177
 
178
  with gr.Row():
179
  # Get available acestep-v15- model list
180
+ available_models = dit_handler.get_available_acestep_v15_models()
181
  default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
182
 
183
  config_path = gr.Dropdown(
 
195
 
196
  with gr.Row():
197
  # Get available 5Hz LM model list
198
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
199
  default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
200
 
201
  lm_model_path = gr.Dropdown(
 
219
  info="Check to initialize 5Hz LM during service initialization",
220
  )
221
  # Auto-detect flash attention availability
222
+ flash_attn_available = dit_handler.is_flash_attention_available()
223
  use_flash_attention_checkbox = gr.Checkbox(
224
  label="Use Flash Attention",
225
  value=flash_attn_available,
 
568
  }
569
 
570
 
571
+ def create_results_section(dit_handler) -> dict:
572
  """Create results display section"""
573
  with gr.Group():
574
  gr.HTML('<div class="section-header"><h3>🎧 Generated Results</h3></div>')
 
623
  }
624
 
625
 
626
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
627
  """Setup event handlers connecting UI components and business logic"""
628
 
629
  def update_init_status(status_msg, enable_btn):
 
632
 
633
  # Dataset handlers
634
  dataset_section["import_dataset_btn"].click(
635
+ fn=dataset_handler.import_dataset,
636
  inputs=[dataset_section["dataset_type"]],
637
  outputs=[dataset_section["data_status"]]
638
  )
639
 
640
  # Service initialization - refresh checkpoints
641
  def refresh_checkpoints():
642
+ choices = dit_handler.get_available_checkpoints()
643
  return gr.update(choices=choices)
644
 
645
  generation_section["refresh_btn"].click(
 
701
  # Service initialization
702
  def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
703
  """Wrapper for service initialization, returns status and button state"""
704
+ # Initialize DiT handler
705
+ status, enable = dit_handler.initialize_service(
706
+ checkpoint, config_path, device,
707
  use_flash_attention=use_flash_attention, compile_model=False,
708
  offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
709
  )
710
+
711
+ # Initialize LM handler if requested
712
+ if init_llm:
713
+ # Get checkpoint directory
714
+ current_file = os.path.abspath(__file__)
715
+ project_root = os.path.dirname(os.path.dirname(current_file))
716
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
717
+
718
+ lm_status, lm_success = llm_handler.initialize(
719
+ checkpoint_dir=checkpoint_dir,
720
+ lm_model_path=lm_model_path,
721
+ backend=backend,
722
+ device=device,
723
+ offload_to_cpu=offload_to_cpu,
724
+ dtype=dit_handler.dtype
725
+ )
726
+
727
+ if lm_success:
728
+ status += f"\n{lm_status}"
729
+ else:
730
+ status += f"\n{lm_status}"
731
+ # Don't fail the entire initialization if LM fails, but log it
732
+ # Keep enable as is (DiT initialization result) even if LM fails
733
+
734
  return status, gr.update(interactive=enable)
735
 
736
  generation_section["init_btn"].click(
 
783
  use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
784
  progress=gr.Progress(track_tqdm=True)
785
  ):
786
+ return dit_handler.generate_music(
787
  captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
788
  time_signature=time_signature, vocal_language=vocal_language,
789
  inference_steps=inference_steps, guidance_scale=guidance_scale,
 
847
  # 5Hz LM generation (simplified version, can be extended as needed)
848
  def generate_lm_hints_wrapper(caption, lyrics, temperature, cfg_scale, negative_prompt):
849
  """Wrapper for 5Hz LM generation"""
850
+ metadata, audio_codes, status = llm_handler.generate_with_5hz_lm(caption, lyrics, temperature, cfg_scale, negative_prompt)
851
 
852
  # Extract metadata values and map to UI fields
853
  # Handle bpm
 
905
  audio_codes_content: str = ""
906
  ) -> tuple:
907
  """Update instruction and UI visibility based on task type."""
908
+ instruction = dit_handler.generate_instruction(
909
  task_type=task_type_value,
910
  track_name=track_name_value,
911
  complete_track_classes=complete_track_classes_value
acestep/handler.py CHANGED
@@ -61,19 +61,9 @@ class AceStepHandler:
61
  # Sample rate
62
  self.sample_rate = 48000
63
 
64
- # 5Hz LM related
65
- self.llm = None
66
- self.llm_tokenizer = None
67
- self.llm_initialized = False
68
- self.llm_backend = None
69
-
70
  # Reward model (temporarily disabled)
71
  self.reward_model = None
72
 
73
- # Dataset related (temporarily disabled)
74
- self.dataset = None
75
- self.dataset_imported = False
76
-
77
  # Batch size
78
  self.batch_size = 2
79
 
@@ -120,22 +110,6 @@ class AceStepHandler:
120
  models.sort()
121
  return models
122
 
123
- def get_available_5hz_lm_models(self) -> List[str]:
124
- """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
125
- current_file = os.path.abspath(__file__)
126
- project_root = os.path.dirname(os.path.dirname(current_file))
127
- checkpoint_dir = os.path.join(project_root, "checkpoints")
128
-
129
- models = []
130
- if os.path.exists(checkpoint_dir):
131
- for item in os.listdir(checkpoint_dir):
132
- item_path = os.path.join(checkpoint_dir, item)
133
- if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
134
- models.append(item)
135
-
136
- models.sort()
137
- return models
138
-
139
  def is_flash_attention_available(self) -> bool:
140
  """Check if flash attention is available on the system"""
141
  try:
@@ -149,9 +123,6 @@ class AceStepHandler:
149
  project_root: str,
150
  config_path: str,
151
  device: str = "auto",
152
- init_llm: bool = False,
153
- lm_model_path: str = "acestep-5Hz-lm-0.6B",
154
- backend: str = "vllm",
155
  use_flash_attention: bool = False,
156
  compile_model: bool = False,
157
  offload_to_cpu: bool = False,
@@ -159,15 +130,12 @@ class AceStepHandler:
159
  quantization: Optional[str] = None,
160
  ) -> Tuple[str, bool]:
161
  """
162
- Initialize model service
163
 
164
  Args:
165
  project_root: Project root path (may be checkpoints directory, will be handled automatically)
166
  config_path: Model config directory name (e.g., "acestep-v15-turbo")
167
  device: Device type
168
- init_llm: Whether to initialize 5Hz LM model
169
- lm_model_path: 5Hz LM model path
170
- backend: Backend for 5Hz LM model ("vllm" or "pt")
171
  use_flash_attention: Whether to use flash attention (requires flash_attn package)
172
  compile_model: Whether to use torch.compile to optimize the model
173
  offload_to_cpu: Whether to offload models to CPU when not in use
@@ -309,72 +277,14 @@ class AceStepHandler:
309
  self.text_encoder.eval()
310
  else:
311
  raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
312
-
313
- # 4. Load 5Hz LM model (optional, only if init_llm is True)
314
- if init_llm:
315
- full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
316
- if os.path.exists(full_lm_model_path):
317
- logger.info("loading 5Hz LM tokenizer...")
318
- start_time = time.time()
319
- llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
320
- logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
321
- self.llm_tokenizer = llm_tokenizer
322
-
323
- # Initialize based on user-selected backend
324
- if backend == "vllm":
325
- # Try to initialize with vllm
326
- status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path)
327
- logger.info(f"5Hz LM status message: {status_msg}")
328
- # Check if initialization failed (status_msg starts with ❌)
329
- if status_msg.startswith("❌"):
330
- # vllm initialization failed, fallback to PyTorch
331
- if not self.llm_initialized:
332
- logger.warning("vllm initialization failed, falling back to PyTorch backend")
333
- try:
334
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
335
- if not self.offload_to_cpu:
336
- self.llm = self.llm.to(device).to(self.dtype)
337
- else:
338
- self.llm = self.llm.to("cpu").to(self.dtype)
339
- self.llm.eval()
340
- self.llm_backend = "pt"
341
- self.llm_initialized = True
342
- logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
343
- except Exception as e:
344
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
345
- # If vllm initialization succeeded, self.llm_initialized should already be True
346
- else:
347
- # Use PyTorch backend (pt)
348
- try:
349
- self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
350
- if not self.offload_to_cpu:
351
- self.llm = self.llm.to(device).to(self.dtype)
352
- else:
353
- self.llm = self.llm.to("cpu").to(self.dtype)
354
- self.llm.eval()
355
- self.llm_backend = "pt"
356
- self.llm_initialized = True
357
- logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
358
- except Exception as e:
359
- return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
360
-
361
- else:
362
- # 5Hz LM path not found
363
- return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
364
 
365
  # Determine actual attention implementation used
366
  actual_attn = getattr(self.config, "_attn_implementation", "eager")
367
 
368
- status_msg = f"✅ Model initialized successfully on {device}\n" + status_msg
369
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
370
  status_msg += f"VAE: {vae_checkpoint_path}\n"
371
  status_msg += f"Text encoder: {text_encoder_path}\n"
372
- if init_llm and hasattr(self, 'llm') and self.llm is not None:
373
- backend_info = getattr(self, 'llm_backend', 'unknown')
374
- status_msg += f"5Hz LM model: {os.path.join(checkpoint_dir, lm_model_path)}\n"
375
- status_msg += f"5Hz LM backend: {backend_info}\n"
376
- else:
377
- status_msg += f"5Hz LM model: Not loaded (checkbox not selected)\n"
378
  status_msg += f"Dtype: {self.dtype}\n"
379
  status_msg += f"Attention: {actual_attn}\n"
380
  status_msg += f"Compiled: {compile_model}\n"
@@ -393,7 +303,7 @@ class AceStepHandler:
393
  Context manager to load a model to GPU and offload it back to CPU after use.
394
 
395
  Args:
396
- model_name: Name of the model to load ("text_encoder", "vae", "model", "llm")
397
  """
398
  if not self.offload_to_cpu:
399
  yield
@@ -418,11 +328,6 @@ class AceStepHandler:
418
  yield
419
  return
420
 
421
- # If model is LLM and using nanovllm, do not offload (it stays on GPU)
422
- if model_name == "llm" and getattr(self, "llm_type", None) == "nanovllm":
423
- yield
424
- return
425
-
426
  model = getattr(self, model_name, None)
427
  if model is None:
428
  yield
@@ -434,10 +339,6 @@ class AceStepHandler:
434
  if model_name == "vae":
435
  vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
436
  model.to(self.device).to(vae_dtype)
437
- elif model_name == "llm" and hasattr(model, "to"):
438
- # Special handling for nanovllm LLM which might have custom to() method or structure
439
- # Assuming it has a .to() method based on our previous edits to nanovllm
440
- model.to(self.device)
441
  else:
442
  model.to(self.device).to(self.dtype)
443
 
@@ -454,10 +355,7 @@ class AceStepHandler:
454
  # Offload to CPU
455
  logger.info(f"Offloading {model_name} to CPU")
456
  start_time = time.time()
457
- if model_name == "llm" and hasattr(model, "to"):
458
- model.to("cpu")
459
- else:
460
- model.to("cpu")
461
 
462
  if model_name == "model" and hasattr(self, "silence_latent"):
463
  self.silence_latent = self.silence_latent.to("cpu")
@@ -467,318 +365,6 @@ class AceStepHandler:
467
  self.current_offload_cost += offload_time
468
  logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
469
 
470
- def import_dataset(self, dataset_type: str) -> str:
471
- """Import dataset (temporarily disabled)"""
472
- self.dataset_imported = False
473
- return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
474
-
475
- def get_item_data(self, *args, **kwargs):
476
- """Get dataset item (temporarily disabled)"""
477
- return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
478
-
479
- def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> float:
480
- """Get GPU memory utilization ratio"""
481
- try:
482
- device = torch.device("cuda:0")
483
- total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
484
- allocated_mem_bytes = torch.cuda.memory_allocated(device)
485
- reserved_mem_bytes = torch.cuda.memory_reserved(device)
486
-
487
- total_gpu = total_gpu_mem_bytes / 1024**3
488
- low_gpu_memory_mode = False
489
- if total_gpu < minimal_gpu:
490
- minimal_gpu = 0.5 * total_gpu
491
- low_gpu_memory_mode = True
492
- allocated_gpu = allocated_mem_bytes / 1024**3
493
- reserved_gpu = reserved_mem_bytes / 1024**3
494
- available_gpu = total_gpu - reserved_gpu
495
-
496
- if available_gpu >= minimal_gpu:
497
- ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu))
498
- else:
499
- ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))
500
-
501
- return ratio, low_gpu_memory_mode
502
- except Exception as e:
503
- return 0.9, low_gpu_memory_mode
504
-
505
- def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
506
- """Initialize 5Hz LM model"""
507
- if not torch.cuda.is_available():
508
- self.llm_initialized = False
509
- logger.error("CUDA is not available. Please check your GPU setup.")
510
- return "❌ CUDA is not available. Please check your GPU setup."
511
- try:
512
- from nanovllm import LLM, SamplingParams
513
- except ImportError:
514
- self.llm_initialized = False
515
- logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
516
- return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
517
-
518
- try:
519
- current_device = torch.cuda.current_device()
520
- device_name = torch.cuda.get_device_name(current_device)
521
-
522
- torch.cuda.empty_cache()
523
- gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
524
- minimal_gpu=8,
525
- min_ratio=0.2,
526
- max_ratio=0.9
527
- )
528
- if low_gpu_memory_mode:
529
- self.max_model_len = 2048
530
- else:
531
- self.max_model_len = 4096
532
-
533
- logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
534
- start_time = time.time()
535
- self.llm = LLM(
536
- model=model_path,
537
- enforce_eager=False,
538
- tensor_parallel_size=1,
539
- max_model_len=self.max_model_len,
540
- gpu_memory_utilization=gpu_memory_utilization,
541
- tokenizer=self.llm_tokenizer,
542
- )
543
- logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
544
- self.llm_initialized = True
545
- self.llm_backend = "vllm"
546
- return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
547
- except Exception as e:
548
- self.llm_initialized = False
549
- self.llm_type = None
550
- error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
551
- return error_msg
552
-
553
- def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
554
- try:
555
- from nanovllm import SamplingParams
556
-
557
- prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
558
-
559
- formatted_prompt = self.llm_tokenizer.apply_chat_template(
560
- [
561
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
562
- {"role": "user", "content": prompt}
563
- ],
564
- tokenize=False,
565
- add_generation_prompt=True,
566
- )
567
- logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
568
-
569
- sampling_params = SamplingParams(max_tokens=self.max_model_len-64, temperature=temperature, cfg_scale=cfg_scale)
570
- # Use CFG if cfg_scale > 1.0
571
- if cfg_scale > 1.0:
572
- # Build unconditional prompt (user input replaced with "NO USER INPUT")
573
- formatted_unconditional_prompt = self.lm_tokenizer.apply_chat_template(
574
- [
575
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
576
- {"role": "user", "content": negative_prompt}
577
- ],
578
- tokenize=False,
579
- add_generation_prompt=True,
580
- )
581
- outputs = self.llm.generate(
582
- [formatted_prompt],
583
- sampling_params,
584
- unconditional_prompts=[formatted_unconditional_prompt]
585
- )
586
- else:
587
- outputs = self.lm_model.generate([formatted_prompt], sampling_params)
588
- # Extract text from output - handle different output formats
589
- if isinstance(outputs, list) and len(outputs) > 0:
590
- if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
591
- output_text = outputs[0].outputs[0].text
592
- elif hasattr(outputs[0], 'text'):
593
- output_text = outputs[0].text
594
- elif isinstance(outputs[0], dict) and 'text' in outputs[0]:
595
- output_text = outputs[0]['text']
596
- else:
597
- output_text = str(outputs[0])
598
- else:
599
- output_text = str(outputs)
600
- metadata, audio_codes = self.parse_lm_output(output_text)
601
- print(f"[debug]output_text: {output_text}")
602
- codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
603
- return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
604
-
605
- except Exception as e:
606
- error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
607
- return {}, "", error_msg
608
-
609
- def generate_with_5hz_lm_pt(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
610
- try:
611
- prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
612
-
613
- formatted_prompt = self.llm_tokenizer.apply_chat_template(
614
- [
615
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
616
- {"role": "user", "content": prompt}
617
- ],
618
- tokenize=False,
619
- add_generation_prompt=True,
620
- )
621
-
622
- # Tokenize the prompt
623
- inputs = self.llm_tokenizer(
624
- formatted_prompt,
625
- return_tensors="pt",
626
- padding=False,
627
- truncation=True,
628
- )
629
-
630
- # Generate with the model
631
- with self._load_model_context("llm"):
632
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
633
-
634
- # Get max_new_tokens from model config or use a default
635
- max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
636
- if hasattr(self, 'max_model_len'):
637
- max_new_tokens = min(max_new_tokens, self.max_model_len)
638
-
639
- # Define custom streamer for tqdm
640
- class TqdmTokenStreamer(BaseStreamer):
641
- def __init__(self, total):
642
- self.pbar = tqdm(total=total, desc="Generating 5Hz tokens", unit="token", maxinterval=1)
643
-
644
- def put(self, value):
645
- # value is tensor of token ids
646
- if value.dim() > 1:
647
- num_tokens = value.numel()
648
- else:
649
- num_tokens = len(value)
650
- self.pbar.update(num_tokens)
651
-
652
- def end(self):
653
- self.pbar.close()
654
-
655
- streamer = TqdmTokenStreamer(total=max_new_tokens)
656
-
657
- with torch.no_grad():
658
- outputs = self.llm.generate(
659
- **inputs,
660
- max_new_tokens=max_new_tokens,
661
- temperature=temperature,
662
- do_sample=True if temperature > 0 else False,
663
- pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
664
- streamer=streamer,
665
- )
666
-
667
- # Decode the generated tokens
668
- # Only decode the newly generated tokens (skip the input prompt)
669
- generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
670
- output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
671
-
672
- metadata, audio_codes = self.parse_lm_output(output_text)
673
- codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
674
- return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
675
-
676
- except Exception as e:
677
- error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
678
- return {}, "", error_msg
679
-
680
- def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
681
- """Generate metadata and audio codes using 5Hz LM"""
682
- # Check if 5Hz LM is initialized
683
- if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
684
- debug_info = f"llm_initialized={getattr(self, 'llm_initialized', 'not set')}, "
685
- debug_info += f"has_llm={hasattr(self, 'llm')}, "
686
- debug_info += f"llm_is_none={getattr(self, 'llm', None) is None}, "
687
- debug_info += f"llm_backend={getattr(self, 'llm_backend', 'not set')}"
688
- return {}, "", f"❌ 5Hz LM not initialized. Please initialize it first. Debug: {debug_info}"
689
-
690
- if not hasattr(self, 'llm') or self.llm is None:
691
- return {}, "", "❌ 5Hz LM model not loaded. Please initialize it first."
692
-
693
- if not hasattr(self, 'llm_backend'):
694
- return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
695
-
696
- if self.llm_backend == "vllm":
697
- return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature, cfg_scale, negative_prompt)
698
- else:
699
- return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)
700
-
701
- def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
702
- """
703
- Parse LM output to extract metadata and audio codes.
704
-
705
- Expected format:
706
- <think>
707
- bpm: 73
708
- duration: 273
709
- genres: Chinese folk
710
- keyscale: G major
711
- timesignature: 4
712
- </think>
713
-
714
- <|audio_code_56535|><|audio_code_62918|>...
715
-
716
- Returns:
717
- Tuple of (metadata_dict, audio_codes_string)
718
- """
719
- debug_output_text = output_text.split("</think>")[0]
720
- logger.debug(f"Debug output text: {debug_output_text}")
721
- metadata = {}
722
- audio_codes = ""
723
-
724
- import re
725
-
726
- # Extract audio codes - find all <|audio_code_XXX|> patterns
727
- code_pattern = r'<\|audio_code_\d+\|>'
728
- code_matches = re.findall(code_pattern, output_text)
729
- if code_matches:
730
- audio_codes = "".join(code_matches)
731
-
732
- # Extract metadata from reasoning section
733
- # Try different reasoning tag patterns
734
- reasoning_patterns = [
735
- r'<think>(.*?)</think>',
736
- r'<think>(.*?)</think>',
737
- r'<reasoning>(.*?)</reasoning>',
738
- ]
739
-
740
- reasoning_text = None
741
- for pattern in reasoning_patterns:
742
- match = re.search(pattern, output_text, re.DOTALL)
743
- if match:
744
- reasoning_text = match.group(1).strip()
745
- break
746
-
747
- # If no reasoning tags found, try to parse metadata from the beginning of output
748
- if not reasoning_text:
749
- # Look for metadata lines before audio codes
750
- lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
751
- reasoning_text = lines_before_codes.strip()
752
-
753
- # Parse metadata fields
754
- if reasoning_text:
755
- for line in reasoning_text.split('\n'):
756
- line = line.strip()
757
- if ':' in line and not line.startswith('<'):
758
- parts = line.split(':', 1)
759
- if len(parts) == 2:
760
- key = parts[0].strip().lower()
761
- value = parts[1].strip()
762
-
763
- if key == 'bpm':
764
- try:
765
- metadata['bpm'] = int(value)
766
- except:
767
- metadata['bpm'] = value
768
- elif key == 'duration':
769
- try:
770
- metadata['duration'] = int(value)
771
- except:
772
- metadata['duration'] = value
773
- elif key == 'genres':
774
- metadata['genres'] = value
775
- elif key == 'keyscale':
776
- metadata['keyscale'] = value
777
- elif key == 'timesignature':
778
- metadata['timesignature'] = value
779
-
780
- return metadata, audio_codes
781
-
782
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
783
  """Process target audio"""
784
  if audio_file is None:
@@ -837,13 +423,13 @@ class AceStepHandler:
837
  detokenizer = self.model.detokenizer
838
 
839
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
840
- indices = torch.tensor(code_ids, device=self.device, dtype=torch.long).unsqueeze(0) # [1, T_5Hz]
 
 
 
841
 
842
- # Expand to include quantizer dimension: [1, T_5Hz, num_quantizers]
843
- if indices.dim() == 2:
844
- indices = indices.unsqueeze(-1).expand(-1, -1, num_quantizers)
845
- print(indices.shape)
846
- # Get quantized representation from indices: [1, T_5Hz, dim]
847
  quantized = quantizer.get_output_from_indices(indices)
848
  if quantized.dtype != self.dtype:
849
  quantized = quantized.to(self.dtype)
 
61
  # Sample rate
62
  self.sample_rate = 48000
63
 
 
 
 
 
 
 
64
  # Reward model (temporarily disabled)
65
  self.reward_model = None
66
 
 
 
 
 
67
  # Batch size
68
  self.batch_size = 2
69
 
 
110
  models.sort()
111
  return models
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def is_flash_attention_available(self) -> bool:
114
  """Check if flash attention is available on the system"""
115
  try:
 
123
  project_root: str,
124
  config_path: str,
125
  device: str = "auto",
 
 
 
126
  use_flash_attention: bool = False,
127
  compile_model: bool = False,
128
  offload_to_cpu: bool = False,
 
130
  quantization: Optional[str] = None,
131
  ) -> Tuple[str, bool]:
132
  """
133
+ Initialize DiT model service
134
 
135
  Args:
136
  project_root: Project root path (may be checkpoints directory, will be handled automatically)
137
  config_path: Model config directory name (e.g., "acestep-v15-turbo")
138
  device: Device type
 
 
 
139
  use_flash_attention: Whether to use flash attention (requires flash_attn package)
140
  compile_model: Whether to use torch.compile to optimize the model
141
  offload_to_cpu: Whether to offload models to CPU when not in use
 
277
  self.text_encoder.eval()
278
  else:
279
  raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  # Determine actual attention implementation used
282
  actual_attn = getattr(self.config, "_attn_implementation", "eager")
283
 
284
+ status_msg = f"✅ Model initialized successfully on {device}\n"
285
  status_msg += f"Main model: {acestep_v15_checkpoint_path}\n"
286
  status_msg += f"VAE: {vae_checkpoint_path}\n"
287
  status_msg += f"Text encoder: {text_encoder_path}\n"
 
 
 
 
 
 
288
  status_msg += f"Dtype: {self.dtype}\n"
289
  status_msg += f"Attention: {actual_attn}\n"
290
  status_msg += f"Compiled: {compile_model}\n"
 
303
  Context manager to load a model to GPU and offload it back to CPU after use.
304
 
305
  Args:
306
+ model_name: Name of the model to load ("text_encoder", "vae", "model")
307
  """
308
  if not self.offload_to_cpu:
309
  yield
 
328
  yield
329
  return
330
 
 
 
 
 
 
331
  model = getattr(self, model_name, None)
332
  if model is None:
333
  yield
 
339
  if model_name == "vae":
340
  vae_dtype = torch.bfloat16 if self.device in ["cuda", "xpu"] else self.dtype
341
  model.to(self.device).to(vae_dtype)
 
 
 
 
342
  else:
343
  model.to(self.device).to(self.dtype)
344
 
 
355
  # Offload to CPU
356
  logger.info(f"Offloading {model_name} to CPU")
357
  start_time = time.time()
358
+ model.to("cpu")
 
 
 
359
 
360
  if model_name == "model" and hasattr(self, "silence_latent"):
361
  self.silence_latent = self.silence_latent.to("cpu")
 
365
  self.current_offload_cost += offload_time
366
  logger.info(f"Offloaded {model_name} to CPU in {offload_time:.4f}s")
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
  def process_target_audio(self, audio_file) -> Optional[torch.Tensor]:
369
  """Process target audio"""
370
  if audio_file is None:
 
423
  detokenizer = self.model.detokenizer
424
 
425
  num_quantizers = getattr(quantizer, "num_quantizers", 1)
426
+ # Create indices tensor: [T_5Hz]
427
+ indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) # [T_5Hz]
428
+
429
+ indices = indices.unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1]
430
 
431
+ # Get quantized representation from indices
432
+ # The quantizer expects [batch, T_5Hz] format and handles quantizer dimension internally
 
 
 
433
  quantized = quantizer.get_output_from_indices(indices)
434
  if quantized.dtype != self.dtype:
435
  quantized = quantized.to(self.dtype)
acestep/llm_inference.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 5Hz LM (Language Model) Handler
3
+ Handles all LM-related operations including initialization and generation
4
+ """
5
+ import os
6
+ import traceback
7
+ import time
8
+ from typing import Optional, Dict, Any, Tuple, List
9
+ from contextlib import contextmanager
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+ from loguru import logger
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
+ from transformers.generation.streamers import BaseStreamer
16
+
17
+
18
+ class LLMHandler:
19
+ """5Hz LM Handler for audio code generation"""
20
+
21
+ def __init__(self):
22
+ """Initialize LLMHandler with default values"""
23
+ self.llm = None
24
+ self.llm_tokenizer = None
25
+ self.llm_initialized = False
26
+ self.llm_backend = None
27
+ self.max_model_len = 4096
28
+ self.device = "cpu"
29
+ self.dtype = torch.float32
30
+ self.offload_to_cpu = False
31
+
32
+ def get_available_5hz_lm_models(self) -> List[str]:
33
+ """Scan and return all model directory names starting with 'acestep-5Hz-lm-'"""
34
+ current_file = os.path.abspath(__file__)
35
+ project_root = os.path.dirname(os.path.dirname(current_file))
36
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
37
+
38
+ models = []
39
+ if os.path.exists(checkpoint_dir):
40
+ for item in os.listdir(checkpoint_dir):
41
+ item_path = os.path.join(checkpoint_dir, item)
42
+ if os.path.isdir(item_path) and item.startswith("acestep-5Hz-lm-"):
43
+ models.append(item)
44
+
45
+ models.sort()
46
+ return models
47
+
48
+ def get_gpu_memory_utilization(self, minimal_gpu: float = 8, min_ratio: float = 0.2, max_ratio: float = 0.9) -> Tuple[float, bool]:
49
+ """Get GPU memory utilization ratio"""
50
+ try:
51
+ device = torch.device("cuda:0")
52
+ total_gpu_mem_bytes = torch.cuda.get_device_properties(device).total_memory
53
+ allocated_mem_bytes = torch.cuda.memory_allocated(device)
54
+ reserved_mem_bytes = torch.cuda.memory_reserved(device)
55
+
56
+ total_gpu = total_gpu_mem_bytes / 1024**3
57
+ low_gpu_memory_mode = False
58
+ if total_gpu < minimal_gpu:
59
+ minimal_gpu = 0.5 * total_gpu
60
+ low_gpu_memory_mode = True
61
+ allocated_gpu = allocated_mem_bytes / 1024**3
62
+ reserved_gpu = reserved_mem_bytes / 1024**3
63
+ available_gpu = total_gpu - reserved_gpu
64
+
65
+ if available_gpu >= minimal_gpu:
66
+ ratio = min(max_ratio, max(min_ratio, minimal_gpu / total_gpu))
67
+ else:
68
+ ratio = min(max_ratio, max(min_ratio, (available_gpu * 0.8) / total_gpu))
69
+
70
+ return ratio, low_gpu_memory_mode
71
+ except Exception as e:
72
+ return 0.9, False
73
+
74
+ def initialize(
75
+ self,
76
+ checkpoint_dir: str,
77
+ lm_model_path: str,
78
+ backend: str = "vllm",
79
+ device: str = "auto",
80
+ offload_to_cpu: bool = False,
81
+ dtype: Optional[torch.dtype] = None,
82
+ ) -> Tuple[str, bool]:
83
+ """
84
+ Initialize 5Hz LM model
85
+
86
+ Args:
87
+ checkpoint_dir: Checkpoint directory path
88
+ lm_model_path: LM model path (relative to checkpoint_dir)
89
+ backend: Backend type ("vllm" or "pt")
90
+ device: Device type ("auto", "cuda", or "cpu")
91
+ offload_to_cpu: Whether to offload to CPU
92
+ dtype: Data type (if None, auto-detect based on device)
93
+
94
+ Returns:
95
+ (status_message, success)
96
+ """
97
+ try:
98
+ if device == "auto":
99
+ device = "cuda" if torch.cuda.is_available() else "cpu"
100
+
101
+ self.device = device
102
+ self.offload_to_cpu = offload_to_cpu
103
+ # Set dtype based on device: bfloat16 for cuda, float32 for cpu
104
+ if dtype is None:
105
+ self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32
106
+ else:
107
+ self.dtype = dtype
108
+
109
+ full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
110
+ if not os.path.exists(full_lm_model_path):
111
+ return f"❌ 5Hz LM model not found at {full_lm_model_path}", False
112
+
113
+ logger.info("loading 5Hz LM tokenizer...")
114
+ start_time = time.time()
115
+ llm_tokenizer = AutoTokenizer.from_pretrained(full_lm_model_path, use_fast=True)
116
+ logger.info(f"5Hz LM tokenizer loaded successfully in {time.time() - start_time:.2f} seconds")
117
+ self.llm_tokenizer = llm_tokenizer
118
+
119
+ # Initialize based on user-selected backend
120
+ if backend == "vllm":
121
+ # Try to initialize with vllm
122
+ status_msg = self._initialize_5hz_lm_vllm(full_lm_model_path)
123
+ logger.info(f"5Hz LM status message: {status_msg}")
124
+ # Check if initialization failed (status_msg starts with ❌)
125
+ if status_msg.startswith("❌"):
126
+ # vllm initialization failed, fallback to PyTorch
127
+ if not self.llm_initialized:
128
+ logger.warning("vllm initialization failed, falling back to PyTorch backend")
129
+ try:
130
+ self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
131
+ if not self.offload_to_cpu:
132
+ self.llm = self.llm.to(device).to(self.dtype)
133
+ else:
134
+ self.llm = self.llm.to("cpu").to(self.dtype)
135
+ self.llm.eval()
136
+ self.llm_backend = "pt"
137
+ self.llm_initialized = True
138
+ logger.info("5Hz LM initialized successfully using PyTorch backend (fallback)")
139
+ status_msg = f"✅ 5Hz LM initialized successfully (PyTorch fallback)\nModel: {full_lm_model_path}\nBackend: PyTorch"
140
+ except Exception as e:
141
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
142
+ # If vllm initialization succeeded, self.llm_initialized should already be True
143
+ else:
144
+ # Use PyTorch backend (pt)
145
+ try:
146
+ self.llm = AutoModelForCausalLM.from_pretrained(full_lm_model_path, trust_remote_code=True)
147
+ if not self.offload_to_cpu:
148
+ self.llm = self.llm.to(device).to(self.dtype)
149
+ else:
150
+ self.llm = self.llm.to("cpu").to(self.dtype)
151
+ self.llm.eval()
152
+ self.llm_backend = "pt"
153
+ self.llm_initialized = True
154
+ logger.info(f"5Hz LM initialized successfully using PyTorch backend on {device}")
155
+ status_msg = f"✅ 5Hz LM initialized successfully\nModel: {full_lm_model_path}\nBackend: PyTorch\nDevice: {device}"
156
+ except Exception as e:
157
+ return f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", False
158
+
159
+ return status_msg, True
160
+
161
+ except Exception as e:
162
+ error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
163
+ return error_msg, False
164
+
165
+ def _initialize_5hz_lm_vllm(self, model_path: str) -> str:
166
+ """Initialize 5Hz LM model using vllm backend"""
167
+ if not torch.cuda.is_available():
168
+ self.llm_initialized = False
169
+ logger.error("CUDA is not available. Please check your GPU setup.")
170
+ return "❌ CUDA is not available. Please check your GPU setup."
171
+ try:
172
+ from nanovllm import LLM, SamplingParams
173
+ except ImportError:
174
+ self.llm_initialized = False
175
+ logger.error("nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install .")
176
+ return "❌ nano-vllm is not installed. Please install it using 'cd acestep/third_parts/nano-vllm && pip install ."
177
+
178
+ try:
179
+ current_device = torch.cuda.current_device()
180
+ device_name = torch.cuda.get_device_name(current_device)
181
+
182
+ torch.cuda.empty_cache()
183
+ gpu_memory_utilization, low_gpu_memory_mode = self.get_gpu_memory_utilization(
184
+ minimal_gpu=8,
185
+ min_ratio=0.2,
186
+ max_ratio=0.9
187
+ )
188
+ if low_gpu_memory_mode:
189
+ self.max_model_len = 2048
190
+ else:
191
+ self.max_model_len = 4096
192
+
193
+ logger.info(f"Initializing 5Hz LM with model: {model_path}, enforce_eager: False, tensor_parallel_size: 1, max_model_len: {self.max_model_len}, gpu_memory_utilization: {gpu_memory_utilization}")
194
+ start_time = time.time()
195
+ self.llm = LLM(
196
+ model=model_path,
197
+ enforce_eager=False,
198
+ tensor_parallel_size=1,
199
+ max_model_len=self.max_model_len,
200
+ gpu_memory_utilization=gpu_memory_utilization,
201
+ tokenizer=self.llm_tokenizer,
202
+ )
203
+ logger.info(f"5Hz LM initialized successfully in {time.time() - start_time:.2f} seconds")
204
+ self.llm_initialized = True
205
+ self.llm_backend = "vllm"
206
+ return f"✅ 5Hz LM initialized successfully\nModel: {model_path}\nDevice: {device_name}\nGPU Memory Utilization: {gpu_memory_utilization:.2f}"
207
+ except Exception as e:
208
+ self.llm_initialized = False
209
+ error_msg = f"❌ Error initializing 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
210
+ return error_msg
211
+
212
+ def generate_with_5hz_lm_vllm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
213
+ """Generate metadata and audio codes using 5Hz LM with vllm backend"""
214
+ try:
215
+ from nanovllm import SamplingParams
216
+
217
+ prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
218
+
219
+ formatted_prompt = self.llm_tokenizer.apply_chat_template(
220
+ [
221
+ {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
222
+ {"role": "user", "content": prompt}
223
+ ],
224
+ tokenize=False,
225
+ add_generation_prompt=True,
226
+ )
227
+ logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
228
+
229
+ sampling_params = SamplingParams(max_tokens=self.max_model_len-64, temperature=temperature, cfg_scale=cfg_scale)
230
+ # Use CFG if cfg_scale > 1.0
231
+ if cfg_scale > 1.0:
232
+ # Build unconditional prompt (user input replaced with "NO USER INPUT")
233
+ formatted_unconditional_prompt = self.llm_tokenizer.apply_chat_template(
234
+ [
235
+ {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
236
+ {"role": "user", "content": negative_prompt}
237
+ ],
238
+ tokenize=False,
239
+ add_generation_prompt=True,
240
+ )
241
+ outputs = self.llm.generate(
242
+ [formatted_prompt],
243
+ sampling_params,
244
+ unconditional_prompts=[formatted_unconditional_prompt]
245
+ )
246
+ else:
247
+ outputs = self.llm.generate([formatted_prompt], sampling_params)
248
+ # Extract text from output - handle different output formats
249
+ if isinstance(outputs, list) and len(outputs) > 0:
250
+ if hasattr(outputs[0], 'outputs') and len(outputs[0].outputs) > 0:
251
+ output_text = outputs[0].outputs[0].text
252
+ elif hasattr(outputs[0], 'text'):
253
+ output_text = outputs[0].text
254
+ elif isinstance(outputs[0], dict) and 'text' in outputs[0]:
255
+ output_text = outputs[0]['text']
256
+ else:
257
+ output_text = str(outputs[0])
258
+ else:
259
+ output_text = str(outputs)
260
+ metadata, audio_codes = self.parse_lm_output(output_text)
261
+ print(f"[debug]output_text: {output_text}")
262
+ codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
263
+ return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"
264
+
265
+ except Exception as e:
266
+ error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
267
+ return {}, "", error_msg
268
+
269
    def generate_with_5hz_lm_pt(self, caption: str, lyrics: str, temperature: float = 0.6) -> Tuple[Dict[str, Any], str, str]:
        """Generate metadata and audio codes using the 5Hz LM (PyTorch backend).

        Builds a chat-formatted prompt from caption and lyrics, runs
        `self.llm.generate` with a tqdm progress streamer, then parses the
        decoded continuation via `parse_lm_output`.

        Args:
            caption: Song caption used as the conditioning prompt.
            lyrics: Lyrics text appended to the prompt.
            temperature: Sampling temperature; 0 disables sampling (greedy).

        Returns:
            (metadata, audio_codes, status_message); on any exception returns
            ({}, "", "❌ ..." error message with traceback).
        """
        try:
            prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"

            formatted_prompt = self.llm_tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
                    {"role": "user", "content": prompt}
                ],
                tokenize=False,
                add_generation_prompt=True,
            )

            # Tokenize the prompt
            inputs = self.llm_tokenizer(
                formatted_prompt,
                return_tensors="pt",
                padding=False,
                truncation=True,
            )

            # Generate with the model (context manager handles CPU<->GPU offload)
            with self._load_model_context():
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Get max_new_tokens from model config or use a default
                # NOTE(review): capping by max_model_len ignores the prompt
                # length, so prompt + continuation may exceed the context
                # window — confirm this is intended.
                max_new_tokens = getattr(self.llm.config, 'max_new_tokens', 4096)
                if hasattr(self, 'max_model_len'):
                    max_new_tokens = min(max_new_tokens, self.max_model_len)

                # Define custom streamer for tqdm progress reporting
                class TqdmTokenStreamer(BaseStreamer):
                    def __init__(self, total):
                        self.pbar = tqdm(total=total, desc="Generating 5Hz tokens", unit="token", maxinterval=1)

                    def put(self, value):
                        # value is tensor of token ids; count all elements
                        if value.dim() > 1:
                            num_tokens = value.numel()
                        else:
                            num_tokens = len(value)
                        self.pbar.update(num_tokens)

                    def end(self):
                        self.pbar.close()

                streamer = TqdmTokenStreamer(total=max_new_tokens)

                with torch.no_grad():
                    outputs = self.llm.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        do_sample=True if temperature > 0 else False,
                        pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                        streamer=streamer,
                    )

            # Decode the generated tokens
            # Only decode the newly generated tokens (skip the input prompt);
            # special tokens are kept because audio codes are special tokens.
            generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
            output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)

            metadata, audio_codes = self.parse_lm_output(output_text)
            codes_count = len(audio_codes.split('<|audio_code_')) - 1 if audio_codes else 0
            return metadata, audio_codes, f"✅ Generated successfully\nOutput length: {len(output_text)} chars\nCodes count: {codes_count}"

        except Exception as e:
            error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
            return {}, "", error_msg
340
+
341
+ def generate_with_5hz_lm(self, caption: str, lyrics: str, temperature: float = 0.6, cfg_scale: float = 1.0, negative_prompt: str = "NO USER INPUT") -> Tuple[Dict[str, Any], str, str]:
342
+ """Generate metadata and audio codes using 5Hz LM"""
343
+ # Check if 5Hz LM is initialized
344
+ if not hasattr(self, 'llm_initialized') or not self.llm_initialized:
345
+ debug_info = f"llm_initialized={getattr(self, 'llm_initialized', 'not set')}, "
346
+ debug_info += f"has_llm={hasattr(self, 'llm')}, "
347
+ debug_info += f"llm_is_none={getattr(self, 'llm', None) is None}, "
348
+ debug_info += f"llm_backend={getattr(self, 'llm_backend', 'not set')}"
349
+ return {}, "", f"❌ 5Hz LM not initialized. Please initialize it first. Debug: {debug_info}"
350
+
351
+ if not hasattr(self, 'llm') or self.llm is None:
352
+ return {}, "", "❌ 5Hz LM model not loaded. Please initialize it first."
353
+
354
+ if not hasattr(self, 'llm_backend'):
355
+ return {}, "", "❌ 5Hz LM backend not set. Please initialize it first."
356
+
357
+ if self.llm_backend == "vllm":
358
+ return self.generate_with_5hz_lm_vllm(caption, lyrics, temperature, cfg_scale, negative_prompt)
359
+ else:
360
+ return self.generate_with_5hz_lm_pt(caption, lyrics, temperature)
361
+
362
+ def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
363
+ """
364
+ Parse LM output to extract metadata and audio codes.
365
+
366
+ Expected format:
367
+ <think>
368
+ bpm: 73
369
+ duration: 273
370
+ genres: Chinese folk
371
+ keyscale: G major
372
+ timesignature: 4
373
+ </think>
374
+
375
+ <|audio_code_56535|><|audio_code_62918|>...
376
+
377
+ Returns:
378
+ Tuple of (metadata_dict, audio_codes_string)
379
+ """
380
+ debug_output_text = output_text.split("</think>")[0]
381
+ logger.debug(f"Debug output text: {debug_output_text}")
382
+ metadata = {}
383
+ audio_codes = ""
384
+
385
+ import re
386
+
387
+ # Extract audio codes - find all <|audio_code_XXX|> patterns
388
+ code_pattern = r'<\|audio_code_\d+\|>'
389
+ code_matches = re.findall(code_pattern, output_text)
390
+ if code_matches:
391
+ audio_codes = "".join(code_matches)
392
+
393
+ # Extract metadata from reasoning section
394
+ # Try different reasoning tag patterns
395
+ reasoning_patterns = [
396
+ r'<think>(.*?)</think>',
397
+ r'<think>(.*?)</think>',
398
+ r'<reasoning>(.*?)</reasoning>',
399
+ ]
400
+
401
+ reasoning_text = None
402
+ for pattern in reasoning_patterns:
403
+ match = re.search(pattern, output_text, re.DOTALL)
404
+ if match:
405
+ reasoning_text = match.group(1).strip()
406
+ break
407
+
408
+ # If no reasoning tags found, try to parse metadata from the beginning of output
409
+ if not reasoning_text:
410
+ # Look for metadata lines before audio codes
411
+ lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
412
+ reasoning_text = lines_before_codes.strip()
413
+
414
+ # Parse metadata fields
415
+ if reasoning_text:
416
+ for line in reasoning_text.split('\n'):
417
+ line = line.strip()
418
+ if ':' in line and not line.startswith('<'):
419
+ parts = line.split(':', 1)
420
+ if len(parts) == 2:
421
+ key = parts[0].strip().lower()
422
+ value = parts[1].strip()
423
+
424
+ if key == 'bpm':
425
+ try:
426
+ metadata['bpm'] = int(value)
427
+ except:
428
+ metadata['bpm'] = value
429
+ elif key == 'duration':
430
+ try:
431
+ metadata['duration'] = int(value)
432
+ except:
433
+ metadata['duration'] = value
434
+ elif key == 'genres':
435
+ metadata['genres'] = value
436
+ elif key == 'keyscale':
437
+ metadata['keyscale'] = value
438
+ elif key == 'timesignature':
439
+ metadata['timesignature'] = value
440
+
441
+ return metadata, audio_codes
442
+
443
+ @contextmanager
444
+ def _load_model_context(self):
445
+ """
446
+ Context manager to load a model to GPU and offload it back to CPU after use.
447
+ Only used for PyTorch backend when offload_to_cpu is True.
448
+ """
449
+ if not self.offload_to_cpu:
450
+ yield
451
+ return
452
+
453
+ # If using nanovllm, do not offload (it stays on GPU)
454
+ if self.llm_backend == "vllm":
455
+ yield
456
+ return
457
+
458
+ model = self.llm
459
+ if model is None:
460
+ yield
461
+ return
462
+
463
+ # Load to GPU
464
+ logger.info(f"Loading LLM to {self.device}")
465
+ start_time = time.time()
466
+ if hasattr(model, "to"):
467
+ model.to(self.device).to(self.dtype)
468
+ load_time = time.time() - start_time
469
+ logger.info(f"Loaded LLM to {self.device} in {load_time:.4f}s")
470
+
471
+ try:
472
+ yield
473
+ finally:
474
+ # Offload to CPU
475
+ logger.info(f"Offloading LLM to CPU")
476
+ start_time = time.time()
477
+ if hasattr(model, "to"):
478
+ model.to("cpu")
479
+ torch.cuda.empty_cache()
480
+ offload_time = time.time() - start_time
481
+ logger.info(f"Offloaded LLM to CPU in {offload_time:.4f}s")
482
+
requirements.txt CHANGED
@@ -5,4 +5,5 @@ gradio
5
  soundfile
6
  loguru
7
  einops
8
- accelerator
 
 
5
  soundfile
6
  loguru
7
  einops
8
+ accelerator
9
+ vector-quantize-pytorch