ChuxiJ commited on
Commit
bb87271
·
1 Parent(s): 858eb3e

add shift support

Browse files
acestep/gradio_ui/events/__init__.py CHANGED
@@ -34,6 +34,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
34
  generation_section["inference_steps"],
35
  generation_section["guidance_scale"],
36
  generation_section["use_adg"],
 
37
  generation_section["cfg_interval_start"],
38
  generation_section["cfg_interval_end"],
39
  generation_section["task_type"],
@@ -235,12 +236,14 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
235
  generation_section["use_adg"],
236
  generation_section["cfg_interval_start"],
237
  generation_section["cfg_interval_end"],
 
238
  generation_section["audio_format"],
239
  generation_section["lm_temperature"],
240
  generation_section["lm_cfg_scale"],
241
  generation_section["lm_top_k"],
242
  generation_section["lm_top_p"],
243
  generation_section["lm_negative_prompt"],
 
244
  generation_section["use_cot_caption"],
245
  generation_section["use_cot_language"],
246
  generation_section["audio_cover_strength"],
@@ -250,6 +253,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
250
  generation_section["repainting_end"],
251
  generation_section["track_name"],
252
  generation_section["complete_track_classes"],
 
253
  results_section["is_format_caption_state"]
254
  ]
255
  )
@@ -396,6 +400,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
396
  generation_section["use_adg"],
397
  generation_section["cfg_interval_start"],
398
  generation_section["cfg_interval_end"],
 
399
  generation_section["audio_format"],
400
  generation_section["lm_temperature"],
401
  generation_section["think_checkbox"],
@@ -547,6 +552,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
547
  generation_section["use_adg"],
548
  generation_section["cfg_interval_start"],
549
  generation_section["cfg_interval_end"],
 
550
  generation_section["audio_format"],
551
  generation_section["lm_temperature"],
552
  generation_section["think_checkbox"],
 
34
  generation_section["inference_steps"],
35
  generation_section["guidance_scale"],
36
  generation_section["use_adg"],
37
+ generation_section["shift"],
38
  generation_section["cfg_interval_start"],
39
  generation_section["cfg_interval_end"],
40
  generation_section["task_type"],
 
236
  generation_section["use_adg"],
237
  generation_section["cfg_interval_start"],
238
  generation_section["cfg_interval_end"],
239
+ generation_section["shift"],
240
  generation_section["audio_format"],
241
  generation_section["lm_temperature"],
242
  generation_section["lm_cfg_scale"],
243
  generation_section["lm_top_k"],
244
  generation_section["lm_top_p"],
245
  generation_section["lm_negative_prompt"],
246
+ generation_section["use_cot_metas"], # Added: use_cot_metas
247
  generation_section["use_cot_caption"],
248
  generation_section["use_cot_language"],
249
  generation_section["audio_cover_strength"],
 
253
  generation_section["repainting_end"],
254
  generation_section["track_name"],
255
  generation_section["complete_track_classes"],
256
+ generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
257
  results_section["is_format_caption_state"]
258
  ]
259
  )
 
400
  generation_section["use_adg"],
401
  generation_section["cfg_interval_start"],
402
  generation_section["cfg_interval_end"],
403
+ generation_section["shift"],
404
  generation_section["audio_format"],
405
  generation_section["lm_temperature"],
406
  generation_section["think_checkbox"],
 
552
  generation_section["use_adg"],
553
  generation_section["cfg_interval_start"],
554
  generation_section["cfg_interval_end"],
555
+ generation_section["shift"],
556
  generation_section["audio_format"],
557
  generation_section["lm_temperature"],
558
  generation_section["think_checkbox"],
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -19,7 +19,7 @@ def load_metadata(file_obj):
19
  """Load generation parameters from a JSON file"""
20
  if file_obj is None:
21
  gr.Warning(t("messages.no_file_selected"))
22
- return [None] * 31 + [False] # Return None for all fields, False for is_format_caption
23
 
24
  try:
25
  # Read the uploaded file
@@ -74,35 +74,38 @@ def load_metadata(file_obj):
74
  lm_top_k = metadata.get('lm_top_k', 0)
75
  lm_top_p = metadata.get('lm_top_p', 0.9)
76
  lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
 
77
  use_cot_caption = metadata.get('use_cot_caption', True)
78
  use_cot_language = metadata.get('use_cot_language', True)
79
  audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
80
- think = metadata.get('think', True)
81
  audio_codes = metadata.get('audio_codes', '')
82
  repainting_start = metadata.get('repainting_start', 0.0)
83
  repainting_end = metadata.get('repainting_end', -1)
84
  track_name = metadata.get('track_name')
85
  complete_track_classes = metadata.get('complete_track_classes', [])
 
 
86
 
87
  gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
88
 
89
  return (
90
  task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
91
  audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
92
- use_adg, cfg_interval_start, cfg_interval_end, audio_format,
93
  lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
94
- use_cot_caption, use_cot_language, audio_cover_strength,
95
  think, audio_codes, repainting_start, repainting_end,
96
- track_name, complete_track_classes,
97
  True # Set is_format_caption to True when loading from file
98
  )
99
 
100
  except json.JSONDecodeError as e:
101
  gr.Warning(t("messages.invalid_json", error=str(e)))
102
- return [None] * 31 + [False]
103
  except Exception as e:
104
  gr.Warning(t("messages.load_error", error=str(e)))
105
- return [None] * 31 + [False]
106
 
107
 
108
  def load_random_example(task_type: str):
@@ -282,21 +285,25 @@ def update_model_type_settings(config_path):
282
  config_path_lower = config_path.lower()
283
 
284
  if "turbo" in config_path_lower:
285
- # Turbo model: max 8 steps, hide CFG/ADG, only show text2music/repaint/cover
 
286
  return (
287
  gr.update(value=8, maximum=8, minimum=1), # inference_steps
288
  gr.update(visible=False), # guidance_scale
289
  gr.update(visible=False), # use_adg
 
290
  gr.update(visible=False), # cfg_interval_start
291
  gr.update(visible=False), # cfg_interval_end
292
  gr.update(choices=TASK_TYPES_TURBO), # task_type
293
  )
294
  elif "base" in config_path_lower:
295
- # Base model: max 100 steps, show CFG/ADG, show all task types
 
296
  return (
297
  gr.update(value=32, maximum=100, minimum=1), # inference_steps
298
  gr.update(visible=True), # guidance_scale
299
  gr.update(visible=True), # use_adg
 
300
  gr.update(visible=True), # cfg_interval_start
301
  gr.update(visible=True), # cfg_interval_end
302
  gr.update(choices=TASK_TYPES_BASE), # task_type
@@ -307,6 +314,7 @@ def update_model_type_settings(config_path):
307
  gr.update(value=8, maximum=8, minimum=1),
308
  gr.update(visible=False),
309
  gr.update(visible=False),
 
310
  gr.update(visible=False),
311
  gr.update(visible=False),
312
  gr.update(choices=TASK_TYPES_TURBO), # task_type
 
19
  """Load generation parameters from a JSON file"""
20
  if file_obj is None:
21
  gr.Warning(t("messages.no_file_selected"))
22
+ return [None] * 34 + [False] # Return None for all fields, False for is_format_caption
23
 
24
  try:
25
  # Read the uploaded file
 
74
  lm_top_k = metadata.get('lm_top_k', 0)
75
  lm_top_p = metadata.get('lm_top_p', 0.9)
76
  lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
77
+ use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
78
  use_cot_caption = metadata.get('use_cot_caption', True)
79
  use_cot_language = metadata.get('use_cot_language', True)
80
  audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
81
+ think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
82
  audio_codes = metadata.get('audio_codes', '')
83
  repainting_start = metadata.get('repainting_start', 0.0)
84
  repainting_end = metadata.get('repainting_end', -1)
85
  track_name = metadata.get('track_name')
86
  complete_track_classes = metadata.get('complete_track_classes', [])
87
+ shift = metadata.get('shift', 3.0) # Default 3.0 for base models
88
+ instrumental = metadata.get('instrumental', False) # Added: read instrumental
89
 
90
  gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
91
 
92
  return (
93
  task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
94
  audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
95
+ use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format,
96
  lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
97
+ use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
98
  think, audio_codes, repainting_start, repainting_end,
99
+ track_name, complete_track_classes, instrumental,
100
  True # Set is_format_caption to True when loading from file
101
  )
102
 
103
  except json.JSONDecodeError as e:
104
  gr.Warning(t("messages.invalid_json", error=str(e)))
105
+ return [None] * 34 + [False]
106
  except Exception as e:
107
  gr.Warning(t("messages.load_error", error=str(e)))
108
+ return [None] * 34 + [False]
109
 
110
 
111
  def load_random_example(task_type: str):
 
285
  config_path_lower = config_path.lower()
286
 
287
  if "turbo" in config_path_lower:
288
+ # Turbo model: max 8 steps, hide CFG/ADG/shift, only show text2music/repaint/cover
289
+ # Shift is not effective for turbo models, default to 1.0
290
  return (
291
  gr.update(value=8, maximum=8, minimum=1), # inference_steps
292
  gr.update(visible=False), # guidance_scale
293
  gr.update(visible=False), # use_adg
294
+ gr.update(value=1.0, visible=False), # shift (not effective for turbo)
295
  gr.update(visible=False), # cfg_interval_start
296
  gr.update(visible=False), # cfg_interval_end
297
  gr.update(choices=TASK_TYPES_TURBO), # task_type
298
  )
299
  elif "base" in config_path_lower:
300
+ # Base model: max 100 steps, show CFG/ADG/shift, show all task types
301
+ # Shift range 1.0~5.0, default 3.0 for base models
302
  return (
303
  gr.update(value=32, maximum=100, minimum=1), # inference_steps
304
  gr.update(visible=True), # guidance_scale
305
  gr.update(visible=True), # use_adg
306
+ gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
307
  gr.update(visible=True), # cfg_interval_start
308
  gr.update(visible=True), # cfg_interval_end
309
  gr.update(choices=TASK_TYPES_BASE), # task_type
 
314
  gr.update(value=8, maximum=8, minimum=1),
315
  gr.update(visible=False),
316
  gr.update(visible=False),
317
+ gr.update(value=1.0, visible=False), # shift default 1.0
318
  gr.update(visible=False),
319
  gr.update(visible=False),
320
  gr.update(choices=TASK_TYPES_TURBO), # task_type
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -267,7 +267,7 @@ def generate_with_progress(
267
  reference_audio, audio_duration, batch_size_input, src_audio,
268
  text2music_audio_code_string, repainting_start, repainting_end,
269
  instruction_display_gen, audio_cover_strength, task_type,
270
- use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
271
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
272
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
273
  constrained_decoding_debug,
@@ -300,6 +300,7 @@ def generate_with_progress(
300
  use_adg=use_adg,
301
  cfg_interval_start=cfg_interval_start,
302
  cfg_interval_end=cfg_interval_end,
 
303
  repainting_start=repainting_start,
304
  repainting_end=repainting_end,
305
  audio_cover_strength=audio_cover_strength,
@@ -650,7 +651,7 @@ def capture_current_params(
650
  reference_audio, audio_duration, batch_size_input, src_audio,
651
  text2music_audio_code_string, repainting_start, repainting_end,
652
  instruction_display_gen, audio_cover_strength, task_type,
653
- use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
654
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
655
  use_cot_metas, use_cot_caption, use_cot_language,
656
  constrained_decoding_debug, allow_lm_batch, auto_score, score_scale, lm_batch_chunk_size,
@@ -686,6 +687,7 @@ def capture_current_params(
686
  "use_adg": use_adg,
687
  "cfg_interval_start": cfg_interval_start,
688
  "cfg_interval_end": cfg_interval_end,
 
689
  "audio_format": audio_format,
690
  "lm_temperature": lm_temperature,
691
  "think_checkbox": think_checkbox,
@@ -713,7 +715,7 @@ def generate_with_batch_management(
713
  reference_audio, audio_duration, batch_size_input, src_audio,
714
  text2music_audio_code_string, repainting_start, repainting_end,
715
  instruction_display_gen, audio_cover_strength, task_type,
716
- use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
717
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
718
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
719
  constrained_decoding_debug,
@@ -741,7 +743,7 @@ def generate_with_batch_management(
741
  reference_audio, audio_duration, batch_size_input, src_audio,
742
  text2music_audio_code_string, repainting_start, repainting_end,
743
  instruction_display_gen, audio_cover_strength, task_type,
744
- use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
745
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
746
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
747
  constrained_decoding_debug,
@@ -811,6 +813,7 @@ def generate_with_batch_management(
811
  "use_adg": use_adg,
812
  "cfg_interval_start": cfg_interval_start,
813
  "cfg_interval_end": cfg_interval_end,
 
814
  "audio_format": audio_format,
815
  "lm_temperature": lm_temperature,
816
  "think_checkbox": think_checkbox,
@@ -964,6 +967,7 @@ def generate_next_batch_background(
964
  params.setdefault("use_adg", False)
965
  params.setdefault("cfg_interval_start", 0.0)
966
  params.setdefault("cfg_interval_end", 1.0)
 
967
  params.setdefault("audio_format", "mp3")
968
  params.setdefault("lm_temperature", 0.85)
969
  params.setdefault("think_checkbox", True)
@@ -1010,6 +1014,7 @@ def generate_next_batch_background(
1010
  use_adg=params.get("use_adg"),
1011
  cfg_interval_start=params.get("cfg_interval_start"),
1012
  cfg_interval_end=params.get("cfg_interval_end"),
 
1013
  audio_format=params.get("audio_format"),
1014
  lm_temperature=params.get("lm_temperature"),
1015
  think_checkbox=params.get("think_checkbox"),
 
267
  reference_audio, audio_duration, batch_size_input, src_audio,
268
  text2music_audio_code_string, repainting_start, repainting_end,
269
  instruction_display_gen, audio_cover_strength, task_type,
270
+ use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
271
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
272
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
273
  constrained_decoding_debug,
 
300
  use_adg=use_adg,
301
  cfg_interval_start=cfg_interval_start,
302
  cfg_interval_end=cfg_interval_end,
303
+ shift=shift,
304
  repainting_start=repainting_start,
305
  repainting_end=repainting_end,
306
  audio_cover_strength=audio_cover_strength,
 
651
  reference_audio, audio_duration, batch_size_input, src_audio,
652
  text2music_audio_code_string, repainting_start, repainting_end,
653
  instruction_display_gen, audio_cover_strength, task_type,
654
+ use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
655
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
656
  use_cot_metas, use_cot_caption, use_cot_language,
657
  constrained_decoding_debug, allow_lm_batch, auto_score, score_scale, lm_batch_chunk_size,
 
687
  "use_adg": use_adg,
688
  "cfg_interval_start": cfg_interval_start,
689
  "cfg_interval_end": cfg_interval_end,
690
+ "shift": shift,
691
  "audio_format": audio_format,
692
  "lm_temperature": lm_temperature,
693
  "think_checkbox": think_checkbox,
 
715
  reference_audio, audio_duration, batch_size_input, src_audio,
716
  text2music_audio_code_string, repainting_start, repainting_end,
717
  instruction_display_gen, audio_cover_strength, task_type,
718
+ use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
719
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
720
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
721
  constrained_decoding_debug,
 
743
  reference_audio, audio_duration, batch_size_input, src_audio,
744
  text2music_audio_code_string, repainting_start, repainting_end,
745
  instruction_display_gen, audio_cover_strength, task_type,
746
+ use_adg, cfg_interval_start, cfg_interval_end, shift, audio_format, lm_temperature,
747
  think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
748
  use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
749
  constrained_decoding_debug,
 
813
  "use_adg": use_adg,
814
  "cfg_interval_start": cfg_interval_start,
815
  "cfg_interval_end": cfg_interval_end,
816
+ "shift": shift,
817
  "audio_format": audio_format,
818
  "lm_temperature": lm_temperature,
819
  "think_checkbox": think_checkbox,
 
967
  params.setdefault("use_adg", False)
968
  params.setdefault("cfg_interval_start", 0.0)
969
  params.setdefault("cfg_interval_end", 1.0)
970
+ params.setdefault("shift", 1.0)
971
  params.setdefault("audio_format", "mp3")
972
  params.setdefault("lm_temperature", 0.85)
973
  params.setdefault("think_checkbox", True)
 
1014
  use_adg=params.get("use_adg"),
1015
  cfg_interval_start=params.get("cfg_interval_start"),
1016
  cfg_interval_end=params.get("cfg_interval_end"),
1017
+ shift=params.get("shift"),
1018
  audio_format=params.get("audio_format"),
1019
  lm_temperature=params.get("lm_temperature"),
1020
  think_checkbox=params.get("think_checkbox"),
acestep/gradio_ui/i18n/en.json CHANGED
@@ -116,6 +116,8 @@
116
  "audio_format_info": "Audio format for saved files",
117
  "use_adg_label": "Use ADG",
118
  "use_adg_info": "Enable Angle Domain Guidance",
 
 
119
  "cfg_interval_start": "CFG Interval Start",
120
  "cfg_interval_end": "CFG Interval End",
121
  "lm_params_title": "🤖 LM Generation Parameters",
 
116
  "audio_format_info": "Audio format for saved files",
117
  "use_adg_label": "Use ADG",
118
  "use_adg_info": "Enable Angle Domain Guidance",
119
+ "shift_label": "Shift",
120
+ "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
121
  "cfg_interval_start": "CFG Interval Start",
122
  "cfg_interval_end": "CFG Interval End",
123
  "lm_params_title": "🤖 LM Generation Parameters",
acestep/gradio_ui/i18n/ja.json CHANGED
@@ -116,6 +116,8 @@
116
  "audio_format_info": "保存ファイルのオーディオフォーマット",
117
  "use_adg_label": "ADG を使用",
118
  "use_adg_info": "角度ドメインガイダンスを有効化",
 
 
119
  "cfg_interval_start": "CFG 間隔開始",
120
  "cfg_interval_end": "CFG 間隔終了",
121
  "lm_params_title": "🤖 LM 生成パラメータ",
 
116
  "audio_format_info": "保存ファイルのオーディオフォーマット",
117
  "use_adg_label": "ADG を使用",
118
  "use_adg_info": "角度ドメインガイダンスを有効化",
119
+ "shift_label": "シフト",
120
+ "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
121
  "cfg_interval_start": "CFG 間隔開始",
122
  "cfg_interval_end": "CFG 間隔終了",
123
  "lm_params_title": "🤖 LM 生成パラメータ",
acestep/gradio_ui/i18n/zh.json CHANGED
@@ -116,6 +116,8 @@
116
  "audio_format_info": "保存文件的音频格式",
117
  "use_adg_label": "使用 ADG",
118
  "use_adg_info": "启用角域引导",
 
 
119
  "cfg_interval_start": "CFG 间隔开始",
120
  "cfg_interval_end": "CFG 间隔结束",
121
  "lm_params_title": "🤖 LM 生成参数",
 
116
  "audio_format_info": "保存文件的音频格式",
117
  "use_adg_label": "使用 ADG",
118
  "use_adg_info": "启用角域引导",
119
+ "shift_label": "Shift",
120
+ "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
121
  "cfg_interval_start": "CFG 间隔开始",
122
  "cfg_interval_end": "CFG 间隔结束",
123
  "lm_params_title": "🤖 LM 生成参数",
acestep/gradio_ui/interfaces/generation.py CHANGED
@@ -436,6 +436,15 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
436
  info=t("generation.use_adg_info"),
437
  visible=False
438
  )
 
 
 
 
 
 
 
 
 
439
 
440
  with gr.Row():
441
  cfg_interval_start = gr.Slider(
@@ -649,6 +658,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None, langua
649
  "use_adg": use_adg,
650
  "cfg_interval_start": cfg_interval_start,
651
  "cfg_interval_end": cfg_interval_end,
 
652
  "audio_format": audio_format,
653
  "output_alignment_preference": output_alignment_preference,
654
  "think_checkbox": think_checkbox,
 
436
  info=t("generation.use_adg_info"),
437
  visible=False
438
  )
439
+ shift = gr.Slider(
440
+ minimum=1.0,
441
+ maximum=5.0,
442
+ value=3.0,
443
+ step=0.1,
444
+ label=t("generation.shift_label"),
445
+ info=t("generation.shift_info"),
446
+ visible=False
447
+ )
448
 
449
  with gr.Row():
450
  cfg_interval_start = gr.Slider(
 
658
  "use_adg": use_adg,
659
  "cfg_interval_start": cfg_interval_start,
660
  "cfg_interval_end": cfg_interval_end,
661
+ "shift": shift,
662
  "audio_format": audio_format,
663
  "output_alignment_preference": output_alignment_preference,
664
  "think_checkbox": think_checkbox,
acestep/handler.py CHANGED
@@ -1785,6 +1785,7 @@ class AceStepHandler:
1785
  use_adg: bool = False,
1786
  cfg_interval_start: float = 0.0,
1787
  cfg_interval_end: float = 1.0,
 
1788
  audio_code_hints: Optional[Union[str, List[str]]] = None,
1789
  infer_method: str = "ode",
1790
  ) -> Dict[str, Any]:
@@ -1948,6 +1949,7 @@ class AceStepHandler:
1948
  "use_adg": use_adg,
1949
  "cfg_interval_start": cfg_interval_start,
1950
  "cfg_interval_end": cfg_interval_end,
 
1951
  }
1952
  logger.info("[service_generate] Generating audio...")
1953
  with self._load_model_context("model"):
@@ -2055,6 +2057,7 @@ class AceStepHandler:
2055
  use_adg: bool = False,
2056
  cfg_interval_start: float = 0.0,
2057
  cfg_interval_end: float = 1.0,
 
2058
  use_tiled_decode: bool = True,
2059
  progress=None
2060
  ) -> Dict[str, Any]:
@@ -2202,6 +2205,7 @@ class AceStepHandler:
2202
  use_adg=use_adg, # Pass use_adg parameter
2203
  cfg_interval_start=cfg_interval_start, # Pass CFG interval start
2204
  cfg_interval_end=cfg_interval_end, # Pass CFG interval end
 
2205
  audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
2206
  return_intermediate=should_return_intermediate
2207
  )
 
1785
  use_adg: bool = False,
1786
  cfg_interval_start: float = 0.0,
1787
  cfg_interval_end: float = 1.0,
1788
+ shift: float = 1.0,
1789
  audio_code_hints: Optional[Union[str, List[str]]] = None,
1790
  infer_method: str = "ode",
1791
  ) -> Dict[str, Any]:
 
1949
  "use_adg": use_adg,
1950
  "cfg_interval_start": cfg_interval_start,
1951
  "cfg_interval_end": cfg_interval_end,
1952
+ "shift": shift,
1953
  }
1954
  logger.info("[service_generate] Generating audio...")
1955
  with self._load_model_context("model"):
 
2057
  use_adg: bool = False,
2058
  cfg_interval_start: float = 0.0,
2059
  cfg_interval_end: float = 1.0,
2060
+ shift: float = 1.0,
2061
  use_tiled_decode: bool = True,
2062
  progress=None
2063
  ) -> Dict[str, Any]:
 
2205
  use_adg=use_adg, # Pass use_adg parameter
2206
  cfg_interval_start=cfg_interval_start, # Pass CFG interval start
2207
  cfg_interval_end=cfg_interval_end, # Pass CFG interval end
2208
+ shift=shift, # Pass shift parameter
2209
  audio_code_hints=audio_code_hints_batch, # Pass audio code hints as list
2210
  return_intermediate=should_return_intermediate
2211
  )
acestep/inference.py CHANGED
@@ -42,6 +42,7 @@ class GenerationParams:
42
  use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
43
  cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
44
  cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
 
45
 
46
  # Task-Specific Parameters
47
  task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
@@ -94,6 +95,7 @@ class GenerationParams:
94
  use_adg: bool = False
95
  cfg_interval_start: float = 0.0
96
  cfg_interval_end: float = 1.0
 
97
 
98
  repainting_start: float = 0.0
99
  repainting_end: float = -1
@@ -485,6 +487,7 @@ def generate_music(
485
  use_adg=params.use_adg,
486
  cfg_interval_start=params.cfg_interval_start,
487
  cfg_interval_end=params.cfg_interval_end,
 
488
  progress=progress,
489
  )
490
 
 
42
  use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
43
  cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
44
  cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
45
+ shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
46
 
47
  # Task-Specific Parameters
48
  task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
 
95
  use_adg: bool = False
96
  cfg_interval_start: float = 0.0
97
  cfg_interval_end: float = 1.0
98
+ shift: float = 1.0
99
 
100
  repainting_start: float = 0.0
101
  repainting_end: float = -1
 
487
  use_adg=params.use_adg,
488
  cfg_interval_start=params.cfg_interval_start,
489
  cfg_interval_end=params.cfg_interval_end,
490
+ shift=params.shift,
491
  progress=progress,
492
  )
493