Spaces:
Running on Zero
Running on Zero
ACE-Step Custom commited on
Commit ·
78910e3
1
Parent(s): ee19acb
Fix syntax error: close unterminated docstring in _load_models
Browse files- src/ace_step_engine.py +9 -6
src/ace_step_engine.py
CHANGED
|
@@ -92,7 +92,7 @@ class ACEStepEngine:
|
|
| 92 |
raise
|
| 93 |
|
| 94 |
def _load_models(self):
|
| 95 |
-
"""Initialize
|
| 96 |
try:
|
| 97 |
if not ACE_STEP_AVAILABLE:
|
| 98 |
raise RuntimeError("ACE-Step 1.5 not available")
|
|
@@ -101,15 +101,18 @@ class ACEStepEngine:
|
|
| 101 |
dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
|
| 102 |
lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
|
| 103 |
|
|
|
|
|
|
|
|
|
|
| 104 |
# Get project root
|
| 105 |
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 106 |
|
| 107 |
-
logger.info(f"
|
| 108 |
|
| 109 |
# Initialize DiT handler (handles main diffusion model, VAE, text encoder)
|
| 110 |
status_dit, success_dit = self.dit_handler.initialize_service(
|
| 111 |
project_root=project_root,
|
| 112 |
-
config_path=dit_model_path,
|
| 113 |
device="auto",
|
| 114 |
use_flash_attention=False,
|
| 115 |
compile_model=False,
|
|
@@ -119,13 +122,13 @@ class ACEStepEngine:
|
|
| 119 |
if not success_dit:
|
| 120 |
raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
|
| 121 |
|
| 122 |
-
logger.info(f" DiT initialized: {status_dit}")
|
| 123 |
-
|
| 124 |
# Initialize LLM handler (handles 5Hz Language Model)
|
| 125 |
logger.info(f"Initializing LLM handler with model: {lm_model_path}")
|
| 126 |
|
| 127 |
status_llm, success_llm = self.llm_handler.initialize(
|
| 128 |
-
checkpoint_dir=
|
| 129 |
lm_model_path=lm_model_path,
|
| 130 |
backend="pt", # Use PyTorch backend for compatibility
|
| 131 |
device="auto",
|
|
|
|
| 92 |
raise
|
| 93 |
|
| 94 |
def _load_models(self):
|
| 95 |
+
"""Initialize and load ACE-Step models."""
|
| 96 |
try:
|
| 97 |
if not ACE_STEP_AVAILABLE:
|
| 98 |
raise RuntimeError("ACE-Step 1.5 not available")
|
|
|
|
| 101 |
dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
|
| 102 |
lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
|
| 103 |
|
| 104 |
+
# Get checkpoints directory using helper function
|
| 105 |
+
checkpoints_dir = get_checkpoints_dir(checkpoint_dir)
|
| 106 |
+
|
| 107 |
# Get project root
|
| 108 |
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 109 |
|
| 110 |
+
logger.info(f"Initializing DiT handler with model: {dit_model_path}")
|
| 111 |
|
| 112 |
# Initialize DiT handler (handles main diffusion model, VAE, text encoder)
|
| 113 |
status_dit, success_dit = self.dit_handler.initialize_service(
|
| 114 |
project_root=project_root,
|
| 115 |
+
config_path=str(checkpoints_dir / dit_model_path),
|
| 116 |
device="auto",
|
| 117 |
use_flash_attention=False,
|
| 118 |
compile_model=False,
|
|
|
|
| 122 |
if not success_dit:
|
| 123 |
raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
|
| 124 |
|
| 125 |
+
logger.info(f"✓ DiT initialized: {status_dit}")
|
| 126 |
+
|
| 127 |
# Initialize LLM handler (handles 5Hz Language Model)
|
| 128 |
logger.info(f"Initializing LLM handler with model: {lm_model_path}")
|
| 129 |
|
| 130 |
status_llm, success_llm = self.llm_handler.initialize(
|
| 131 |
+
checkpoint_dir=str(checkpoints_dir),
|
| 132 |
lm_model_path=lm_model_path,
|
| 133 |
backend="pt", # Use PyTorch backend for compatibility
|
| 134 |
device="auto",
|