ACE-Step Custom committed on
Commit
78910e3
·
1 Parent(s): ee19acb

Fix syntax error: close unterminated docstring in _load_models

Browse files
Files changed (1) hide show
  1. src/ace_step_engine.py +9 -6
src/ace_step_engine.py CHANGED
@@ -92,7 +92,7 @@ class ACEStepEngine:
92
  raise
93
 
94
  def _load_models(self):
95
- """Initialize s_dir = get_checkpoints_dir(self.config.get("checkpoint_dir")
96
  try:
97
  if not ACE_STEP_AVAILABLE:
98
  raise RuntimeError("ACE-Step 1.5 not available")
@@ -101,15 +101,18 @@ class ACEStepEngine:
101
  dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
102
  lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
103
 
 
 
 
104
  # Get project root
105
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
106
 
107
- logger.info(f"Instr(checkpoints_dir / dit_model_path) handler with model: {dit_model_path}")
108
 
109
  # Initialize DiT handler (handles main diffusion model, VAE, text encoder)
110
  status_dit, success_dit = self.dit_handler.initialize_service(
111
  project_root=project_root,
112
- config_path=dit_model_path,
113
  device="auto",
114
  use_flash_attention=False,
115
  compile_model=False,
@@ -119,13 +122,13 @@ class ACEStepEngine:
119
  if not success_dit:
120
  raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
121
 
122
- logger.info(f" DiT initialized: {status_dit}")
123
- str(checkpoints_dir)
124
  # Initialize LLM handler (handles 5Hz Language Model)
125
  logger.info(f"Initializing LLM handler with model: {lm_model_path}")
126
 
127
  status_llm, success_llm = self.llm_handler.initialize(
128
- checkpoint_dir=checkpoint_dir,
129
  lm_model_path=lm_model_path,
130
  backend="pt", # Use PyTorch backend for compatibility
131
  device="auto",
 
92
  raise
93
 
94
  def _load_models(self):
95
+ """Initialize and load ACE-Step models."""
96
  try:
97
  if not ACE_STEP_AVAILABLE:
98
  raise RuntimeError("ACE-Step 1.5 not available")
 
101
  dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
102
  lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
103
 
104
+ # Get checkpoints directory using helper function
105
+ checkpoints_dir = get_checkpoints_dir(checkpoint_dir)
106
+
107
  # Get project root
108
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
109
 
110
+ logger.info(f"Initializing DiT handler with model: {dit_model_path}")
111
 
112
  # Initialize DiT handler (handles main diffusion model, VAE, text encoder)
113
  status_dit, success_dit = self.dit_handler.initialize_service(
114
  project_root=project_root,
115
+ config_path=str(checkpoints_dir / dit_model_path),
116
  device="auto",
117
  use_flash_attention=False,
118
  compile_model=False,
 
122
  if not success_dit:
123
  raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
124
 
125
+ logger.info(f" DiT initialized: {status_dit}")
126
+
127
  # Initialize LLM handler (handles 5Hz Language Model)
128
  logger.info(f"Initializing LLM handler with model: {lm_model_path}")
129
 
130
  status_llm, success_llm = self.llm_handler.initialize(
131
+ checkpoint_dir=str(checkpoints_dir),
132
  lm_model_path=lm_model_path,
133
  backend="pt", # Use PyTorch backend for compatibility
134
  device="auto",