ChuxiJ commited on
Commit
6813c41
·
1 Parent(s): b7a9cc9

fix: move @spaces.GPU decorators to local wrappers to fix ZeroGPU pickling error

Browse files

On ZeroGPU, Gradio serializes (pickles) event handler functions to send
them to GPU worker processes. When fn= was a lambda capturing dit_handler,
Python tried to pickle the entire handler including AceStepDiTModel which
contains unpicklable local objects in __init__, causing:
'Can't pickle local object AceStepDiTModel.__init__'

Fix: Remove @_get_spaces_gpu_decorator from module-level functions and
instead apply it to local wrapper functions defined inside
setup_event_handlers(), passed directly as fn= (same pattern as the
working generation_wrapper).

Affected handlers:
- process_source_audio (Analyze button)
- handle_create_sample (Create Sample button)
- handle_format_sample (Format button)
- calculate_score_handler_with_selection (Score buttons)
- generate_lrc_handler (LRC buttons)

acestep/gradio_ui/events/__init__.py CHANGED
@@ -263,10 +263,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
263
 
264
  # ========== Format Button ==========
265
  # Note: cfg_scale and negative_prompt are not supported in format mode
266
- generation_section["format_btn"].click(
267
- fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
 
268
  llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
269
- ),
 
 
 
270
  inputs=[
271
  generation_section["captions"],
272
  generation_section["lyrics"],
@@ -312,8 +316,16 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
312
 
313
  # ========== Process Source Audio Button ==========
314
  # Combines Convert to Codes + Transcribe in one step
 
 
 
 
 
 
 
 
315
  generation_section["process_src_btn"].click(
316
- fn=lambda src, debug: gen_h.process_source_audio(dit_handler, llm_handler, src, debug),
317
  inputs=[
318
  generation_section["src_audio"],
319
  generation_section["constrained_decoding_debug"]
@@ -353,10 +365,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
353
 
354
  # ========== Create Sample Button (Simple Mode) ==========
355
  # Note: cfg_scale and negative_prompt are not supported in create_sample mode
356
- generation_section["create_sample_btn"].click(
357
- fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
 
358
  llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
359
- ),
 
 
 
360
  inputs=[
361
  generation_section["simple_query_input"],
362
  generation_section["simple_instrumental_checkbox"],
@@ -593,10 +609,15 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
593
 
594
  # ========== Score Calculation Handlers ==========
595
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
 
 
596
  def make_score_handler(idx):
597
- return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
598
- dit_handler, llm_handler, idx, scale, batch_idx, queue
599
- )
 
 
 
600
 
601
  for btn_idx in range(1, 9):
602
  results_section[f"score_btn_{btn_idx}"].click(
@@ -616,9 +637,12 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
616
  # ========== LRC Timestamp Handlers ==========
617
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
618
  def make_lrc_handler(idx):
619
- return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
620
- dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
621
- )
 
 
 
622
 
623
  for btn_idx in range(1, 9):
624
  results_section[f"lrc_btn_{btn_idx}"].click(
 
263
 
264
  # ========== Format Button ==========
265
  # Note: cfg_scale and negative_prompt are not supported in format mode
266
+ @_get_spaces_gpu_decorator(duration=180)
267
+ def handle_format_sample_wrapper(caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug):
268
+ return gen_h.handle_format_sample(
269
  llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
270
+ )
271
+
272
+ generation_section["format_btn"].click(
273
+ fn=handle_format_sample_wrapper,
274
  inputs=[
275
  generation_section["captions"],
276
  generation_section["lyrics"],
 
316
 
317
  # ========== Process Source Audio Button ==========
318
  # Combines Convert to Codes + Transcribe in one step
319
+ # Note: @spaces.GPU decorator must be on the function passed directly to fn=,
320
+ # not on a module-level function wrapped in a lambda. Lambdas capturing handler
321
+ # objects cause pickling errors on ZeroGPU because the model contains unpicklable
322
+ # local objects (e.g. AceStepDiTModel.__init__ lambdas).
323
+ @_get_spaces_gpu_decorator(duration=180)
324
+ def process_source_audio_wrapper(src, debug):
325
+ return gen_h.process_source_audio(dit_handler, llm_handler, src, debug)
326
+
327
  generation_section["process_src_btn"].click(
328
+ fn=process_source_audio_wrapper,
329
  inputs=[
330
  generation_section["src_audio"],
331
  generation_section["constrained_decoding_debug"]
 
365
 
366
  # ========== Create Sample Button (Simple Mode) ==========
367
  # Note: cfg_scale and negative_prompt are not supported in create_sample mode
368
+ @_get_spaces_gpu_decorator(duration=180)
369
+ def handle_create_sample_wrapper(query, instrumental, vocal_lang, temp, top_k, top_p, debug):
370
+ return gen_h.handle_create_sample(
371
  llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
372
+ )
373
+
374
+ generation_section["create_sample_btn"].click(
375
+ fn=handle_create_sample_wrapper,
376
  inputs=[
377
  generation_section["simple_query_input"],
378
  generation_section["simple_instrumental_checkbox"],
 
609
 
610
  # ========== Score Calculation Handlers ==========
611
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
612
+ # Note: @spaces.GPU decorator applied here (not on module-level function) to avoid
613
+ # pickling issues on ZeroGPU when handler objects are captured in closures.
614
  def make_score_handler(idx):
615
+ @_get_spaces_gpu_decorator(duration=240)
616
+ def score_handler(scale, batch_idx, queue):
617
+ return res_h.calculate_score_handler_with_selection(
618
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
619
+ )
620
+ return score_handler
621
 
622
  for btn_idx in range(1, 9):
623
  results_section[f"score_btn_{btn_idx}"].click(
 
637
  # ========== LRC Timestamp Handlers ==========
638
  # Use default argument to capture btn_idx value at definition time (Python closure fix)
639
  def make_lrc_handler(idx):
640
+ @_get_spaces_gpu_decorator(duration=240)
641
+ def lrc_handler(batch_idx, queue, vocal_lang, infer_steps):
642
+ return res_h.generate_lrc_handler(
643
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
644
+ )
645
+ return lrc_handler
646
 
647
  for btn_idx in range(1, 9):
648
  results_section[f"lrc_btn_{btn_idx}"].click(
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -766,7 +766,6 @@ def handle_generation_mode_change(mode: str):
766
  think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
767
  )
768
 
769
- @_get_spaces_gpu_decorator(duration=180)
770
  def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
771
  """
772
  Process source audio: convert to codes and then transcribe.
@@ -819,7 +818,6 @@ def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decodi
819
  True # Set is_format_caption to True
820
  )
821
 
822
- @_get_spaces_gpu_decorator(duration=180)
823
  def handle_create_sample(
824
  llm_handler,
825
  query: str,
@@ -949,7 +947,6 @@ def handle_create_sample(
949
  result.status_message, # status_output
950
  )
951
 
952
- @_get_spaces_gpu_decorator(duration=180)
953
  def handle_format_sample(
954
  llm_handler,
955
  caption: str,
 
766
  think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
767
  )
768
 
 
769
  def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
770
  """
771
  Process source audio: convert to codes and then transcribe.
 
818
  True # Set is_format_caption to True
819
  )
820
 
 
821
  def handle_create_sample(
822
  llm_handler,
823
  query: str,
 
947
  result.status_message, # status_output
948
  )
949
 
 
950
  def handle_format_sample(
951
  llm_handler,
952
  caption: str,
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -1058,7 +1058,6 @@ def calculate_score_handler(
1058
  error_msg = t("messages.score_error", error=str(e)) + f"\n{traceback.format_exc()}"
1059
  return error_msg
1060
 
1061
- @_get_spaces_gpu_decorator(duration=240)
1062
  def calculate_score_handler_with_selection(
1063
  dit_handler,
1064
  llm_handler,
@@ -1172,7 +1171,6 @@ def calculate_score_handler_with_selection(
1172
  batch_queue
1173
  )
1174
 
1175
- @_get_spaces_gpu_decorator(duration=240)
1176
  def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
1177
  """
1178
  Generate LRC timestamps for a specific audio sample.
 
1058
  error_msg = t("messages.score_error", error=str(e)) + f"\n{traceback.format_exc()}"
1059
  return error_msg
1060
 
 
1061
  def calculate_score_handler_with_selection(
1062
  dit_handler,
1063
  llm_handler,
 
1171
  batch_queue
1172
  )
1173
 
 
1174
  def generate_lrc_handler(dit_handler, sample_idx, current_batch_index, batch_queue, vocal_language, inference_steps):
1175
  """
1176
  Generate LRC timestamps for a specific audio sample.