Spaces:
Running
on
Zero
Running
on
Zero
support shift3 download
Browse files- acestep/handler.py +47 -12
acestep/handler.py
CHANGED
|
@@ -117,16 +117,39 @@ class AceStepHandler:
|
|
| 117 |
models.sort()
|
| 118 |
return models
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
|
| 121 |
"""
|
| 122 |
Ensure model is downloaded from HuggingFace Hub.
|
| 123 |
Used for HuggingFace Space auto-download support.
|
| 124 |
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
Args:
|
| 129 |
-
model_name: Model directory name (e.g., "acestep-v15-turbo")
|
| 130 |
checkpoint_dir: Target checkpoint directory
|
| 131 |
|
| 132 |
Returns:
|
|
@@ -134,9 +157,6 @@ class AceStepHandler:
|
|
| 134 |
"""
|
| 135 |
from huggingface_hub import snapshot_download
|
| 136 |
|
| 137 |
-
# Unified repository containing all models
|
| 138 |
-
REPO_ID = "ACE-Step/Ace-Step1.5"
|
| 139 |
-
|
| 140 |
model_path = os.path.join(checkpoint_dir, model_name)
|
| 141 |
|
| 142 |
# Check if model already exists
|
|
@@ -144,18 +164,33 @@ class AceStepHandler:
|
|
| 144 |
logger.info(f"Model {model_name} already exists at {model_path}")
|
| 145 |
return model_path
|
| 146 |
|
| 147 |
-
#
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
try:
|
| 151 |
snapshot_download(
|
| 152 |
-
repo_id=
|
| 153 |
-
local_dir=
|
| 154 |
local_dir_use_symlinks=False,
|
| 155 |
)
|
| 156 |
-
logger.info(f"Repository {
|
| 157 |
except Exception as e:
|
| 158 |
-
logger.error(f"Failed to download repository {
|
| 159 |
raise
|
| 160 |
|
| 161 |
return model_path
|
|
|
|
| 117 |
models.sort()
|
| 118 |
return models
|
| 119 |
|
| 120 |
+
# Model name to HuggingFace repository mapping
|
| 121 |
+
# Models in the same repo will be downloaded together
|
| 122 |
+
MODEL_REPO_MAPPING = {
|
| 123 |
+
# Main unified repository (contains acestep-v15-turbo, LM models, VAE, text encoder)
|
| 124 |
+
"acestep-v15-turbo": "ACE-Step/Ace-Step1.5",
|
| 125 |
+
"acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5",
|
| 126 |
+
"acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5",
|
| 127 |
+
"vae": "ACE-Step/Ace-Step1.5",
|
| 128 |
+
"Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5",
|
| 129 |
+
|
| 130 |
+
# Separate model repositories
|
| 131 |
+
"acestep-v15-base": "ACE-Step/acestep-v15-base",
|
| 132 |
+
"acestep-v15-sft": "ACE-Step/acestep-v15-sft",
|
| 133 |
+
"acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
# Default fallback repository for unknown models
|
| 137 |
+
DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5"
|
| 138 |
+
|
| 139 |
def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
|
| 140 |
"""
|
| 141 |
Ensure model is downloaded from HuggingFace Hub.
|
| 142 |
Used for HuggingFace Space auto-download support.
|
| 143 |
|
| 144 |
+
Supports multiple repositories:
|
| 145 |
+
- Models in MODEL_REPO_MAPPING will be downloaded from their specific repo
|
| 146 |
+
- Unknown models will try the DEFAULT_REPO_ID
|
| 147 |
+
|
| 148 |
+
For separate model repos (acestep-v15-base, acestep-v15-sft, acestep-v15-turbo-shift3),
|
| 149 |
+
downloads directly into the model subdirectory.
|
| 150 |
|
| 151 |
Args:
|
| 152 |
+
model_name: Model directory name (e.g., "acestep-v15-turbo", "acestep-v15-turbo-shift3")
|
| 153 |
checkpoint_dir: Target checkpoint directory
|
| 154 |
|
| 155 |
Returns:
|
|
|
|
| 157 |
"""
|
| 158 |
from huggingface_hub import snapshot_download
|
| 159 |
|
|
|
|
|
|
|
|
|
|
| 160 |
model_path = os.path.join(checkpoint_dir, model_name)
|
| 161 |
|
| 162 |
# Check if model already exists
|
|
|
|
| 164 |
logger.info(f"Model {model_name} already exists at {model_path}")
|
| 165 |
return model_path
|
| 166 |
|
| 167 |
+
# Get repository ID for this model
|
| 168 |
+
repo_id = self.MODEL_REPO_MAPPING.get(model_name, self.DEFAULT_REPO_ID)
|
| 169 |
+
|
| 170 |
+
# Determine if this is a unified repo or a separate model repo
|
| 171 |
+
is_unified_repo = repo_id == self.DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5"
|
| 172 |
+
|
| 173 |
+
if is_unified_repo:
|
| 174 |
+
# Unified repo: download entire repo to checkpoint_dir
|
| 175 |
+
# The model will be in checkpoint_dir/model_name
|
| 176 |
+
download_dir = checkpoint_dir
|
| 177 |
+
logger.info(f"Downloading unified repository {repo_id} to {download_dir}...")
|
| 178 |
+
else:
|
| 179 |
+
# Separate model repo: download directly to model_path
|
| 180 |
+
# The repo contains the model files directly, not in a subdirectory
|
| 181 |
+
download_dir = model_path
|
| 182 |
+
os.makedirs(download_dir, exist_ok=True)
|
| 183 |
+
logger.info(f"Downloading model {model_name} from {repo_id} to {download_dir}...")
|
| 184 |
|
| 185 |
try:
|
| 186 |
snapshot_download(
|
| 187 |
+
repo_id=repo_id,
|
| 188 |
+
local_dir=download_dir,
|
| 189 |
local_dir_use_symlinks=False,
|
| 190 |
)
|
| 191 |
+
logger.info(f"Repository {repo_id} downloaded successfully to {download_dir}")
|
| 192 |
except Exception as e:
|
| 193 |
+
logger.error(f"Failed to download repository {repo_id}: {e}")
|
| 194 |
raise
|
| 195 |
|
| 196 |
return model_path
|