ChuxiJ commited on
Commit
fb234f6
·
1 Parent(s): 42e0725

support shift3 download

Browse files
Files changed (1) hide show
  1. acestep/handler.py +47 -12
acestep/handler.py CHANGED
@@ -117,16 +117,39 @@ class AceStepHandler:
117
  models.sort()
118
  return models
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
121
  """
122
  Ensure model is downloaded from HuggingFace Hub.
123
  Used for HuggingFace Space auto-download support.
124
 
125
- Downloads the unified ACE-Step/Ace-Step1.5 repository which contains
126
- both acestep-v15-turbo and acestep-5Hz-lm-1.7B models.
 
 
 
 
127
 
128
  Args:
129
- model_name: Model directory name (e.g., "acestep-v15-turbo")
130
  checkpoint_dir: Target checkpoint directory
131
 
132
  Returns:
@@ -134,9 +157,6 @@ class AceStepHandler:
134
  """
135
  from huggingface_hub import snapshot_download
136
 
137
- # Unified repository containing all models
138
- REPO_ID = "ACE-Step/Ace-Step1.5"
139
-
140
  model_path = os.path.join(checkpoint_dir, model_name)
141
 
142
  # Check if model already exists
@@ -144,18 +164,33 @@ class AceStepHandler:
144
  logger.info(f"Model {model_name} already exists at {model_path}")
145
  return model_path
146
 
147
- # Download the entire repository to checkpoint_dir
148
- logger.info(f"Downloading {REPO_ID} to {checkpoint_dir}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  try:
151
  snapshot_download(
152
- repo_id=REPO_ID,
153
- local_dir=checkpoint_dir,
154
  local_dir_use_symlinks=False,
155
  )
156
- logger.info(f"Repository {REPO_ID} downloaded successfully")
157
  except Exception as e:
158
- logger.error(f"Failed to download repository {REPO_ID}: {e}")
159
  raise
160
 
161
  return model_path
 
117
  models.sort()
118
  return models
119
 
120
+ # Model name to HuggingFace repository mapping
121
+ # Models in the same repo will be downloaded together
122
+ MODEL_REPO_MAPPING = {
123
+ # Main unified repository (contains acestep-v15-turbo, LM models, VAE, text encoder)
124
+ "acestep-v15-turbo": "ACE-Step/Ace-Step1.5",
125
+ "acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5",
126
+ "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5",
127
+ "vae": "ACE-Step/Ace-Step1.5",
128
+ "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5",
129
+
130
+ # Separate model repositories
131
+ "acestep-v15-base": "ACE-Step/acestep-v15-base",
132
+ "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
133
+ "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
134
+ }
135
+
136
+ # Default fallback repository for unknown models
137
+ DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5"
138
+
139
  def _ensure_model_downloaded(self, model_name: str, checkpoint_dir: str) -> str:
140
  """
141
  Ensure model is downloaded from HuggingFace Hub.
142
  Used for HuggingFace Space auto-download support.
143
 
144
+ Supports multiple repositories:
145
+ - Models in MODEL_REPO_MAPPING will be downloaded from their specific repo
146
+ - Unknown models will try the DEFAULT_REPO_ID
147
+
148
+ For separate model repos (acestep-v15-base, acestep-v15-sft, acestep-v15-turbo-shift3),
149
+ downloads directly into the model subdirectory.
150
 
151
  Args:
152
+ model_name: Model directory name (e.g., "acestep-v15-turbo", "acestep-v15-turbo-shift3")
153
  checkpoint_dir: Target checkpoint directory
154
 
155
  Returns:
 
157
  """
158
  from huggingface_hub import snapshot_download
159
 
 
 
 
160
  model_path = os.path.join(checkpoint_dir, model_name)
161
 
162
  # Check if model already exists
 
164
  logger.info(f"Model {model_name} already exists at {model_path}")
165
  return model_path
166
 
167
+ # Get repository ID for this model
168
+ repo_id = self.MODEL_REPO_MAPPING.get(model_name, self.DEFAULT_REPO_ID)
169
+
170
+ # Determine if this is a unified repo or a separate model repo
171
+ is_unified_repo = repo_id == self.DEFAULT_REPO_ID or repo_id == "ACE-Step/Ace-Step1.5"
172
+
173
+ if is_unified_repo:
174
+ # Unified repo: download entire repo to checkpoint_dir
175
+ # The model will be in checkpoint_dir/model_name
176
+ download_dir = checkpoint_dir
177
+ logger.info(f"Downloading unified repository {repo_id} to {download_dir}...")
178
+ else:
179
+ # Separate model repo: download directly to model_path
180
+ # The repo contains the model files directly, not in a subdirectory
181
+ download_dir = model_path
182
+ os.makedirs(download_dir, exist_ok=True)
183
+ logger.info(f"Downloading model {model_name} from {repo_id} to {download_dir}...")
184
 
185
  try:
186
  snapshot_download(
187
+ repo_id=repo_id,
188
+ local_dir=download_dir,
189
  local_dir_use_symlinks=False,
190
  )
191
+ logger.info(f"Repository {repo_id} downloaded successfully to {download_dir}")
192
  except Exception as e:
193
+ logger.error(f"Failed to download repository {repo_id}: {e}")
194
  raise
195
 
196
  return model_path