ChuxiJ committed on
Commit
9fe9970
·
1 Parent(s): 0b990cd

refactor ui (fix repaint/cover score bug)

Browse files
Dockerfile CHANGED
@@ -44,7 +44,7 @@ USER user
44
  RUN pip install --no-cache-dir --user -r requirements.txt
45
 
46
  # Install nano-vllm with --no-deps since all dependencies are already installed
47
- RUN pip install ./acestep/third_parts/nano-vllm
48
 
49
  # Copy the rest of the application
50
  COPY --chown=user:user . .
 
44
  RUN pip install --no-cache-dir --user -r requirements.txt
45
 
46
  # Install nano-vllm with --no-deps since all dependencies are already installed
47
+ RUN pip install --no-deps ./acestep/third_parts/nano-vllm
48
 
49
  # Copy the rest of the application
50
  COPY --chown=user:user . .
acestep/gradio_ui/events/__init__.py CHANGED
@@ -286,6 +286,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
286
  generation_section["simple_sample_created"],
287
  generation_section["src_audio_group"],
288
  generation_section["audio_cover_strength"],
 
289
  ]
290
  )
291
 
@@ -680,6 +681,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
680
  args_list[11] = result.duration # audio_duration
681
  # Enable thinking for Simple mode
682
  args_list[28] = True # think_checkbox
 
 
683
 
684
  # Determine which handler to use
685
  active_handler = dit_handler # Default to primary handler
 
286
  generation_section["simple_sample_created"],
287
  generation_section["src_audio_group"],
288
  generation_section["audio_cover_strength"],
289
+ generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
290
  ]
291
  )
292
 
 
681
  args_list[11] = result.duration # audio_duration
682
  # Enable thinking for Simple mode
683
  args_list[28] = True # think_checkbox
684
+ # Mark as formatted caption (LM-generated sample)
685
+ args_list[36] = True # is_format_caption_state
686
 
687
  # Determine which handler to use
688
  active_handler = dit_handler # Default to primary handler
acestep/gradio_ui/events/generation_handlers.py CHANGED
@@ -710,6 +710,7 @@ def handle_generation_mode_change(mode: str):
710
  - simple_sample_created (reset state)
711
  - src_audio_group (visibility) - shown for cover and repaint
712
  - audio_cover_strength (visibility) - shown only for cover mode
 
713
  """
714
  is_simple = mode == "simple"
715
  is_custom = mode == "custom"
@@ -725,6 +726,13 @@ def handle_generation_mode_change(mode: str):
725
  }
726
  task_type_value = task_type_map.get(mode, "text2music")
727
 
 
 
 
 
 
 
 
728
  return (
729
  gr.update(visible=is_simple), # simple_mode_group
730
  gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
@@ -735,6 +743,7 @@ def handle_generation_mode_change(mode: str):
735
  False, # simple_sample_created - reset to False on mode change
736
  gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
737
  gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
 
738
  )
739
 
740
 
 
710
  - simple_sample_created (reset state)
711
  - src_audio_group (visibility) - shown for cover and repaint
712
  - audio_cover_strength (visibility) - shown only for cover mode
713
+ - think_checkbox (value and interactive) - disabled for cover/repaint modes
714
  """
715
  is_simple = mode == "simple"
716
  is_custom = mode == "custom"
 
726
  }
727
  task_type_value = task_type_map.get(mode, "text2music")
728
 
729
+ # think_checkbox: disabled and set to False for cover/repaint modes
730
+ # (these modes don't use LM thinking, they use source audio codes)
731
+ if is_cover or is_repaint:
732
+ think_checkbox_update = gr.update(value=False, interactive=False)
733
+ else:
734
+ think_checkbox_update = gr.update(value=True, interactive=True)
735
+
736
  return (
737
  gr.update(visible=is_simple), # simple_mode_group
738
  gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
 
743
  False, # simple_sample_created - reset to False on mode change
744
  gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
745
  gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
746
+ think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
747
  )
748
 
749
 
acestep/gradio_ui/events/results_handlers.py CHANGED
@@ -277,7 +277,7 @@ def _build_generation_info(
277
  avg_per_song = generation_total / num_audios if num_audios > 0 else 0
278
  gen_lines = [
279
  f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
280
- f"**{avg_per_song:.2f}s per song**",
281
  ]
282
  if lm_total > 0:
283
  gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
@@ -874,6 +874,9 @@ def calculate_score_handler(
874
  PMI (Pointwise Mutual Information) removes condition bias:
875
  score = log P(condition|codes) - log P(condition)
876
 
 
 
 
877
  Args:
878
  llm_handler: LLM handler instance
879
  audio_codes_str: Generated audio codes string
@@ -895,63 +898,74 @@ def calculate_score_handler(
895
  """
896
  from acestep.test_time_scaling import calculate_pmi_score_per_condition
897
 
898
- if not llm_handler.llm_initialized:
899
- return t("messages.lm_not_initialized")
900
 
901
- if not audio_codes_str or not audio_codes_str.strip():
 
 
902
  return t("messages.no_codes")
903
 
904
  try:
905
- # Build metadata dictionary from both LM metadata and user inputs
906
- metadata = {}
907
-
908
- # Priority 1: Use LM-generated metadata if available
909
- if lm_metadata and isinstance(lm_metadata, dict):
910
- metadata.update(lm_metadata)
911
-
912
- # Priority 2: Add user-provided metadata (if not already in LM metadata)
913
- if bpm is not None and 'bpm' not in metadata:
914
- try:
915
- metadata['bpm'] = int(bpm)
916
- except:
917
- pass
918
-
919
- if caption and 'caption' not in metadata:
920
- metadata['caption'] = caption
921
-
922
- if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
923
- try:
924
- metadata['duration'] = int(audio_duration)
925
- except:
926
- pass
927
-
928
- if key_scale and key_scale.strip() and 'keyscale' not in metadata:
929
- metadata['keyscale'] = key_scale.strip()
930
-
931
- if vocal_language and vocal_language.strip() and 'language' not in metadata:
932
- metadata['language'] = vocal_language.strip()
933
-
934
- if time_signature and time_signature.strip() and 'timesignature' not in metadata:
935
- metadata['timesignature'] = time_signature.strip()
936
-
937
- # Calculate per-condition scores with appropriate metrics
938
- # - Metadata fields (bpm, duration, etc.): Top-k recall
939
- # - Caption and lyrics: PMI (normalized)
940
- scores_per_condition, global_score, status = calculate_pmi_score_per_condition(
941
- llm_handler=llm_handler,
942
- audio_codes=audio_codes_str,
943
- caption=caption or "",
944
- lyrics=lyrics or "",
945
- metadata=metadata if metadata else None,
946
- temperature=1.0,
947
- topk=10,
948
- score_scale=score_scale
949
- )
950
-
951
  alignment_report = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
 
953
- # Only calculate if we have the handler, tensor data, and actual lyrics
954
- if dit_handler and extra_tensor_data and lyrics and lyrics.strip():
955
  try:
956
  align_result = dit_handler.get_lyric_score(
957
  pred_latent=extra_tensor_data.get('pred_latent'),
@@ -978,29 +992,46 @@ def calculate_score_handler(
978
  except Exception as e:
979
  alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
980
 
981
- # Format display string with per-condition breakdown
982
- if global_score == 0.0 and not scores_per_condition:
983
- return t("messages.score_failed", error=status)
984
- else:
985
- # Build per-condition scores display
986
- condition_lines = []
987
- for condition_name, score_value in sorted(scores_per_condition.items()):
988
- condition_lines.append(
989
- f" • {condition_name}: {score_value:.4f}"
990
- )
991
-
992
- conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
 
 
 
 
 
 
 
993
 
994
- final_output = (
995
- f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
996
- f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n"
997
- )
998
 
999
- if alignment_report:
1000
- final_output += alignment_report + "\n"
1001
 
1002
- final_output += "Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI"
1003
- return final_output
 
 
 
 
 
 
 
 
 
 
1004
 
1005
  except Exception as e:
1006
  import traceback
 
277
  avg_per_song = generation_total / num_audios if num_audios > 0 else 0
278
  gen_lines = [
279
  f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
280
+ f"\n**{avg_per_song:.2f}s per song**",
281
  ]
282
  if lm_total > 0:
283
  gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
 
874
  PMI (Pointwise Mutual Information) removes condition bias:
875
  score = log P(condition|codes) - log P(condition)
876
 
877
+ For Cover/Repaint modes where audio_codes may not be available,
878
+ falls back to DiT alignment scoring only.
879
+
880
  Args:
881
  llm_handler: LLM handler instance
882
  audio_codes_str: Generated audio codes string
 
898
  """
899
  from acestep.test_time_scaling import calculate_pmi_score_per_condition
900
 
901
+ has_audio_codes = audio_codes_str and audio_codes_str.strip()
902
+ has_dit_alignment_data = dit_handler and extra_tensor_data and lyrics and lyrics.strip()
903
 
904
+ # Check if we can compute any scores
905
+ if not has_audio_codes and not has_dit_alignment_data:
906
+ # No audio codes and no DiT alignment data - can't compute any score
907
  return t("messages.no_codes")
908
 
909
  try:
910
+ scores_per_condition = {}
911
+ global_score = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
912
  alignment_report = ""
913
+
914
+ # PMI-based scoring (requires audio codes and LLM)
915
+ if has_audio_codes:
916
+ if not llm_handler.llm_initialized:
917
+ # Can still try DiT alignment if available
918
+ if not has_dit_alignment_data:
919
+ return t("messages.lm_not_initialized")
920
+ else:
921
+ # Build metadata dictionary from both LM metadata and user inputs
922
+ metadata = {}
923
+
924
+ # Priority 1: Use LM-generated metadata if available
925
+ if lm_metadata and isinstance(lm_metadata, dict):
926
+ metadata.update(lm_metadata)
927
+
928
+ # Priority 2: Add user-provided metadata (if not already in LM metadata)
929
+ if bpm is not None and 'bpm' not in metadata:
930
+ try:
931
+ metadata['bpm'] = int(bpm)
932
+ except:
933
+ pass
934
+
935
+ if caption and 'caption' not in metadata:
936
+ metadata['caption'] = caption
937
+
938
+ if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
939
+ try:
940
+ metadata['duration'] = int(audio_duration)
941
+ except:
942
+ pass
943
+
944
+ if key_scale and key_scale.strip() and 'keyscale' not in metadata:
945
+ metadata['keyscale'] = key_scale.strip()
946
+
947
+ if vocal_language and vocal_language.strip() and 'language' not in metadata:
948
+ metadata['language'] = vocal_language.strip()
949
+
950
+ if time_signature and time_signature.strip() and 'timesignature' not in metadata:
951
+ metadata['timesignature'] = time_signature.strip()
952
+
953
+ # Calculate per-condition scores with appropriate metrics
954
+ # - Metadata fields (bpm, duration, etc.): Top-k recall
955
+ # - Caption and lyrics: PMI (normalized)
956
+ scores_per_condition, global_score, status = calculate_pmi_score_per_condition(
957
+ llm_handler=llm_handler,
958
+ audio_codes=audio_codes_str,
959
+ caption=caption or "",
960
+ lyrics=lyrics or "",
961
+ metadata=metadata if metadata else None,
962
+ temperature=1.0,
963
+ topk=10,
964
+ score_scale=score_scale
965
+ )
966
 
967
+ # DiT alignment scoring (works even without audio codes - for Cover/Repaint modes)
968
+ if has_dit_alignment_data:
969
  try:
970
  align_result = dit_handler.get_lyric_score(
971
  pred_latent=extra_tensor_data.get('pred_latent'),
 
992
  except Exception as e:
993
  alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
994
 
995
+ # Format display string
996
+ if has_audio_codes and llm_handler.llm_initialized:
997
+ # Full scoring with PMI + alignment
998
+ if global_score == 0.0 and not scores_per_condition:
999
+ # PMI scoring failed but we might have alignment
1000
+ if alignment_report and not alignment_report.startswith("\n⚠️"):
1001
+ final_output = "📊 DiT Alignment Scores (LM codes not available):\n"
1002
+ final_output += alignment_report
1003
+ return final_output
1004
+ return t("messages.score_failed", error="PMI scoring returned no results")
1005
+ else:
1006
+ # Build per-condition scores display
1007
+ condition_lines = []
1008
+ for condition_name, score_value in sorted(scores_per_condition.items()):
1009
+ condition_lines.append(
1010
+ f" • {condition_name}: {score_value:.4f}"
1011
+ )
1012
+
1013
+ conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
1014
 
1015
+ final_output = (
1016
+ f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
1017
+ f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n"
1018
+ )
1019
 
1020
+ if alignment_report:
1021
+ final_output += alignment_report + "\n"
1022
 
1023
+ final_output += "Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI"
1024
+ return final_output
1025
+ else:
1026
+ # Only DiT alignment available (Cover/Repaint mode fallback)
1027
+ if alignment_report and not alignment_report.startswith("\n⚠️"):
1028
+ final_output = "📊 DiT Alignment Scores (LM codes not available for Cover/Repaint mode):\n"
1029
+ final_output += alignment_report
1030
+ return final_output
1031
+ elif alignment_report:
1032
+ return alignment_report
1033
+ else:
1034
+ return "⚠️ No scoring data available"
1035
 
1036
  except Exception as e:
1037
  import traceback