ACE-Step Custom commited on
Commit
052ca84
Β·
1 Parent(s): 6ccd18b

Fix: Implement lazy model loading for ZeroGPU compatibility

Browse files

- Models now load on first use instead of startup

- Device detection happens within GPU context

- Added GPU duration timeouts for all generation functions

- Ensures GPU is acquired before model initialization

Files changed (2) hide show
  1. app.py +6 -6
  2. src/ace_step_engine.py +30 -23
app.py CHANGED
@@ -61,7 +61,7 @@ def get_audio_processor():
61
 
62
  # ==================== TAB 1: STANDARD ACE-STEP GUI ====================
63
 
64
- @spaces.GPU
65
  def standard_generate(
66
  prompt: str,
67
  lyrics: str,
@@ -100,7 +100,7 @@ def standard_generate(
100
  return None, f"❌ Error: {str(e)}"
101
 
102
 
103
- @spaces.GPU
104
  def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str, str]:
105
  """Generate variation of existing audio."""
106
  try:
@@ -110,7 +110,7 @@ def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str,
110
  return None, f"❌ Error: {str(e)}"
111
 
112
 
113
- @spaces.GPU
114
  def standard_repaint(
115
  audio_path: str,
116
  start_time: float,
@@ -124,7 +124,7 @@ def standard_repaint(
124
  except Exception as e:
125
  return None, f"❌ Error: {str(e)}"
126
 
127
- @spaces.GPU
128
 
129
  def standard_lyric_edit(
130
  audio_path: str,
@@ -139,7 +139,7 @@ def standard_lyric_edit(
139
 
140
 
141
  # ==================== TAB 2: CUSTOM TIMELINE WORKFLOW ====================
142
- @spaces.GPU
143
 
144
  def timeline_generate(
145
  prompt: str,
@@ -232,7 +232,7 @@ def timeline_extend(
232
  prompt, lyrics, context_length, "auto", 0.7, -1, session_state
233
  )
234
 
235
- @spaces.GPU
236
 
237
  def timeline_inpaint(
238
  start_time: float,
 
61
 
62
  # ==================== TAB 1: STANDARD ACE-STEP GUI ====================
63
 
64
+ @spaces.GPU(duration=300)
65
  def standard_generate(
66
  prompt: str,
67
  lyrics: str,
 
100
  return None, f"❌ Error: {str(e)}"
101
 
102
 
103
+ @spaces.GPU(duration=180)
104
  def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str, str]:
105
  """Generate variation of existing audio."""
106
  try:
 
110
  return None, f"❌ Error: {str(e)}"
111
 
112
 
113
+ @spaces.GPU(duration=180)
114
  def standard_repaint(
115
  audio_path: str,
116
  start_time: float,
 
124
  except Exception as e:
125
  return None, f"❌ Error: {str(e)}"
126
 
127
+ @spaces.GPU(duration=180)
128
 
129
  def standard_lyric_edit(
130
  audio_path: str,
 
139
 
140
 
141
  # ==================== TAB 2: CUSTOM TIMELINE WORKFLOW ====================
142
+ @spaces.GPU(duration=300)
143
 
144
  def timeline_generate(
145
  prompt: str,
 
232
  prompt, lyrics, context_length, "auto", 0.7, -1, session_state
233
  )
234
 
235
+ @spaces.GPU(duration=240)
236
 
237
  def timeline_inpaint(
238
  start_time: float,
src/ace_step_engine.py CHANGED
@@ -34,33 +34,18 @@ class ACEStepEngine:
34
  config: Configuration dictionary
35
  """
36
  self.config = config
37
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  self._initialized = False
39
  self.dit_handler = None
40
  self.llm_handler = None
41
 
42
- logger.info(f"ACE-Step Engine initializing on {self.device}")
43
 
44
  if not ACE_STEP_AVAILABLE:
45
  logger.error("ACE-Step 1.5 modules not available")
46
  logger.error("Please ensure acestep package is installed in your environment")
47
  return
48
 
49
- try:
50
- # Initialize official handlers
51
- self.dit_handler = AceStepHandler()
52
- self.llm_handler = LLMHandler()
53
-
54
- # Download and load models
55
- self._download_checkpoints()
56
- self._load_models()
57
-
58
- logger.info("βœ“ ACE-Step Engine fully initialized")
59
- except Exception as e:
60
- logger.error(f"Failed to initialize ACE-Step Engine: {e}")
61
- logger.error("Engine will not be available for generation")
62
- import traceback
63
- traceback.print_exc()
64
 
65
  def _download_checkpoints(self):
66
  """Download model checkpoints from HuggingFace if not present."""
@@ -150,6 +135,30 @@ class ACEStepEngine:
150
  logger.error(f"Failed to initialize models: {e}")
151
  raise
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def generate(
154
  self,
155
  prompt: str,
@@ -177,10 +186,8 @@ class ACEStepEngine:
177
  Returns:
178
  Path to generated audio file
179
  """
180
- if not self._initialized:
181
- error_msg = "❌ Engine not initialized - ACE-Step 1.5 may not be installed or models are not loaded"
182
- logger.error(error_msg)
183
- raise RuntimeError(error_msg)
184
 
185
  try:
186
  # Prepare generation parameters
@@ -266,8 +273,8 @@ class ACEStepEngine:
266
 
267
  def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
268
  """Generate variation of existing audio."""
269
- if not self._initialized:
270
- raise RuntimeError("Engine not initialized")
271
 
272
  try:
273
  params = GenerationParams(
 
34
  config: Configuration dictionary
35
  """
36
  self.config = config
 
37
  self._initialized = False
38
  self.dit_handler = None
39
  self.llm_handler = None
40
 
41
+ logger.info(f"ACE-Step Engine created (GPU will be detected on first use)")
42
 
43
  if not ACE_STEP_AVAILABLE:
44
  logger.error("ACE-Step 1.5 modules not available")
45
  logger.error("Please ensure acestep package is installed in your environment")
46
  return
47
 
48
+ logger.info("βœ“ ACE-Step Engine created (models will load on first use)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def _download_checkpoints(self):
51
  """Download model checkpoints from HuggingFace if not present."""
 
135
  logger.error(f"Failed to initialize models: {e}")
136
  raise
137
 
138
+ def _ensure_models_loaded(self):
139
+ """Ensure models are loaded (lazy loading for ZeroGPU compatibility)."""
140
+ if not self._initialized:
141
+ logger.info("Lazy loading models on first use...")
142
+
143
+ # Detect device now (within GPU context on ZeroGPU)
144
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
145
+ logger.info(f"Using device: {self.device}")
146
+
147
+ # Create handlers if not already created
148
+ if self.dit_handler is None:
149
+ self.dit_handler = AceStepHandler()
150
+ if self.llm_handler is None:
151
+ self.llm_handler = LLMHandler()
152
+
153
+ try:
154
+ # Download and load models
155
+ self._download_checkpoints()
156
+ self._load_models()
157
+ logger.info("βœ“ Models loaded successfully")
158
+ except Exception as e:
159
+ logger.error(f"Failed to load models: {e}")
160
+ raise
161
+
162
  def generate(
163
  self,
164
  prompt: str,
 
186
  Returns:
187
  Path to generated audio file
188
  """
189
+ # Ensure models are loaded (lazy loading for ZeroGPU)
190
+ self._ensure_models_loaded()
 
 
191
 
192
  try:
193
  # Prepare generation parameters
 
273
 
274
  def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
275
  """Generate variation of existing audio."""
276
+ # Ensure models are loaded (lazy loading for ZeroGPU)
277
+ self._ensure_models_loaded()
278
 
279
  try:
280
  params = GenerationParams(