Spaces:
Running
on
A100
Running
on
A100
refact ui (fix repaint/cover score bug)
Browse files
Dockerfile
CHANGED
|
@@ -44,7 +44,7 @@ USER user
|
|
| 44 |
RUN pip install --no-cache-dir --user -r requirements.txt
|
| 45 |
|
| 46 |
# Install nano-vllm with --no-deps since all dependencies are already installed
|
| 47 |
-
RUN pip install ./acestep/third_parts/nano-vllm
|
| 48 |
|
| 49 |
# Copy the rest of the application
|
| 50 |
COPY --chown=user:user . .
|
|
|
|
| 44 |
RUN pip install --no-cache-dir --user -r requirements.txt
|
| 45 |
|
| 46 |
# Install nano-vllm with --no-deps since all dependencies are already installed
|
| 47 |
+
RUN pip install --no-deps ./acestep/third_parts/nano-vllm
|
| 48 |
|
| 49 |
# Copy the rest of the application
|
| 50 |
COPY --chown=user:user . .
|
acestep/gradio_ui/events/__init__.py
CHANGED
|
@@ -286,6 +286,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 286 |
generation_section["simple_sample_created"],
|
| 287 |
generation_section["src_audio_group"],
|
| 288 |
generation_section["audio_cover_strength"],
|
|
|
|
| 289 |
]
|
| 290 |
)
|
| 291 |
|
|
@@ -680,6 +681,8 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 680 |
args_list[11] = result.duration # audio_duration
|
| 681 |
# Enable thinking for Simple mode
|
| 682 |
args_list[28] = True # think_checkbox
|
|
|
|
|
|
|
| 683 |
|
| 684 |
# Determine which handler to use
|
| 685 |
active_handler = dit_handler # Default to primary handler
|
|
|
|
| 286 |
generation_section["simple_sample_created"],
|
| 287 |
generation_section["src_audio_group"],
|
| 288 |
generation_section["audio_cover_strength"],
|
| 289 |
+
generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
|
| 290 |
]
|
| 291 |
)
|
| 292 |
|
|
|
|
| 681 |
args_list[11] = result.duration # audio_duration
|
| 682 |
# Enable thinking for Simple mode
|
| 683 |
args_list[28] = True # think_checkbox
|
| 684 |
+
# Mark as formatted caption (LM-generated sample)
|
| 685 |
+
args_list[36] = True # is_format_caption_state
|
| 686 |
|
| 687 |
# Determine which handler to use
|
| 688 |
active_handler = dit_handler # Default to primary handler
|
acestep/gradio_ui/events/generation_handlers.py
CHANGED
|
@@ -710,6 +710,7 @@ def handle_generation_mode_change(mode: str):
|
|
| 710 |
- simple_sample_created (reset state)
|
| 711 |
- src_audio_group (visibility) - shown for cover and repaint
|
| 712 |
- audio_cover_strength (visibility) - shown only for cover mode
|
|
|
|
| 713 |
"""
|
| 714 |
is_simple = mode == "simple"
|
| 715 |
is_custom = mode == "custom"
|
|
@@ -725,6 +726,13 @@ def handle_generation_mode_change(mode: str):
|
|
| 725 |
}
|
| 726 |
task_type_value = task_type_map.get(mode, "text2music")
|
| 727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
return (
|
| 729 |
gr.update(visible=is_simple), # simple_mode_group
|
| 730 |
gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
|
|
@@ -735,6 +743,7 @@ def handle_generation_mode_change(mode: str):
|
|
| 735 |
False, # simple_sample_created - reset to False on mode change
|
| 736 |
gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
|
| 737 |
gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
|
|
|
|
| 738 |
)
|
| 739 |
|
| 740 |
|
|
|
|
| 710 |
- simple_sample_created (reset state)
|
| 711 |
- src_audio_group (visibility) - shown for cover and repaint
|
| 712 |
- audio_cover_strength (visibility) - shown only for cover mode
|
| 713 |
+
- think_checkbox (value and interactive) - disabled for cover/repaint modes
|
| 714 |
"""
|
| 715 |
is_simple = mode == "simple"
|
| 716 |
is_custom = mode == "custom"
|
|
|
|
| 726 |
}
|
| 727 |
task_type_value = task_type_map.get(mode, "text2music")
|
| 728 |
|
| 729 |
+
# think_checkbox: disabled and set to False for cover/repaint modes
|
| 730 |
+
# (these modes don't use LM thinking, they use source audio codes)
|
| 731 |
+
if is_cover or is_repaint:
|
| 732 |
+
think_checkbox_update = gr.update(value=False, interactive=False)
|
| 733 |
+
else:
|
| 734 |
+
think_checkbox_update = gr.update(value=True, interactive=True)
|
| 735 |
+
|
| 736 |
return (
|
| 737 |
gr.update(visible=is_simple), # simple_mode_group
|
| 738 |
gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
|
|
|
|
| 743 |
False, # simple_sample_created - reset to False on mode change
|
| 744 |
gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
|
| 745 |
gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
|
| 746 |
+
think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
|
| 747 |
)
|
| 748 |
|
| 749 |
|
acestep/gradio_ui/events/results_handlers.py
CHANGED
|
@@ -277,7 +277,7 @@ def _build_generation_info(
|
|
| 277 |
avg_per_song = generation_total / num_audios if num_audios > 0 else 0
|
| 278 |
gen_lines = [
|
| 279 |
f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
|
| 280 |
-
f"**{avg_per_song:.2f}s per song**",
|
| 281 |
]
|
| 282 |
if lm_total > 0:
|
| 283 |
gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
|
|
@@ -874,6 +874,9 @@ def calculate_score_handler(
|
|
| 874 |
PMI (Pointwise Mutual Information) removes condition bias:
|
| 875 |
score = log P(condition|codes) - log P(condition)
|
| 876 |
|
|
|
|
|
|
|
|
|
|
| 877 |
Args:
|
| 878 |
llm_handler: LLM handler instance
|
| 879 |
audio_codes_str: Generated audio codes string
|
|
@@ -895,63 +898,74 @@ def calculate_score_handler(
|
|
| 895 |
"""
|
| 896 |
from acestep.test_time_scaling import calculate_pmi_score_per_condition
|
| 897 |
|
| 898 |
-
|
| 899 |
-
|
| 900 |
|
| 901 |
-
if
|
|
|
|
|
|
|
| 902 |
return t("messages.no_codes")
|
| 903 |
|
| 904 |
try:
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
# Priority 1: Use LM-generated metadata if available
|
| 909 |
-
if lm_metadata and isinstance(lm_metadata, dict):
|
| 910 |
-
metadata.update(lm_metadata)
|
| 911 |
-
|
| 912 |
-
# Priority 2: Add user-provided metadata (if not already in LM metadata)
|
| 913 |
-
if bpm is not None and 'bpm' not in metadata:
|
| 914 |
-
try:
|
| 915 |
-
metadata['bpm'] = int(bpm)
|
| 916 |
-
except:
|
| 917 |
-
pass
|
| 918 |
-
|
| 919 |
-
if caption and 'caption' not in metadata:
|
| 920 |
-
metadata['caption'] = caption
|
| 921 |
-
|
| 922 |
-
if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
|
| 923 |
-
try:
|
| 924 |
-
metadata['duration'] = int(audio_duration)
|
| 925 |
-
except:
|
| 926 |
-
pass
|
| 927 |
-
|
| 928 |
-
if key_scale and key_scale.strip() and 'keyscale' not in metadata:
|
| 929 |
-
metadata['keyscale'] = key_scale.strip()
|
| 930 |
-
|
| 931 |
-
if vocal_language and vocal_language.strip() and 'language' not in metadata:
|
| 932 |
-
metadata['language'] = vocal_language.strip()
|
| 933 |
-
|
| 934 |
-
if time_signature and time_signature.strip() and 'timesignature' not in metadata:
|
| 935 |
-
metadata['timesignature'] = time_signature.strip()
|
| 936 |
-
|
| 937 |
-
# Calculate per-condition scores with appropriate metrics
|
| 938 |
-
# - Metadata fields (bpm, duration, etc.): Top-k recall
|
| 939 |
-
# - Caption and lyrics: PMI (normalized)
|
| 940 |
-
scores_per_condition, global_score, status = calculate_pmi_score_per_condition(
|
| 941 |
-
llm_handler=llm_handler,
|
| 942 |
-
audio_codes=audio_codes_str,
|
| 943 |
-
caption=caption or "",
|
| 944 |
-
lyrics=lyrics or "",
|
| 945 |
-
metadata=metadata if metadata else None,
|
| 946 |
-
temperature=1.0,
|
| 947 |
-
topk=10,
|
| 948 |
-
score_scale=score_scale
|
| 949 |
-
)
|
| 950 |
-
|
| 951 |
alignment_report = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
|
| 953 |
-
#
|
| 954 |
-
if
|
| 955 |
try:
|
| 956 |
align_result = dit_handler.get_lyric_score(
|
| 957 |
pred_latent=extra_tensor_data.get('pred_latent'),
|
|
@@ -978,29 +992,46 @@ def calculate_score_handler(
|
|
| 978 |
except Exception as e:
|
| 979 |
alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
|
| 980 |
|
| 981 |
-
# Format display string
|
| 982 |
-
if
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
)
|
| 991 |
-
|
| 992 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 993 |
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
|
| 1002 |
-
|
| 1003 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1004 |
|
| 1005 |
except Exception as e:
|
| 1006 |
import traceback
|
|
|
|
| 277 |
avg_per_song = generation_total / num_audios if num_audios > 0 else 0
|
| 278 |
gen_lines = [
|
| 279 |
f"**🎵 Total generation time {songs_label}: {generation_total:.2f}s**",
|
| 280 |
+
f"\n**{avg_per_song:.2f}s per song**",
|
| 281 |
]
|
| 282 |
if lm_total > 0:
|
| 283 |
gen_lines.append(f"- LM phase {songs_label}: {lm_total:.2f}s")
|
|
|
|
| 874 |
PMI (Pointwise Mutual Information) removes condition bias:
|
| 875 |
score = log P(condition|codes) - log P(condition)
|
| 876 |
|
| 877 |
+
For Cover/Repaint modes where audio_codes may not be available,
|
| 878 |
+
falls back to DiT alignment scoring only.
|
| 879 |
+
|
| 880 |
Args:
|
| 881 |
llm_handler: LLM handler instance
|
| 882 |
audio_codes_str: Generated audio codes string
|
|
|
|
| 898 |
"""
|
| 899 |
from acestep.test_time_scaling import calculate_pmi_score_per_condition
|
| 900 |
|
| 901 |
+
has_audio_codes = audio_codes_str and audio_codes_str.strip()
|
| 902 |
+
has_dit_alignment_data = dit_handler and extra_tensor_data and lyrics and lyrics.strip()
|
| 903 |
|
| 904 |
+
# Check if we can compute any scores
|
| 905 |
+
if not has_audio_codes and not has_dit_alignment_data:
|
| 906 |
+
# No audio codes and no DiT alignment data - can't compute any score
|
| 907 |
return t("messages.no_codes")
|
| 908 |
|
| 909 |
try:
|
| 910 |
+
scores_per_condition = {}
|
| 911 |
+
global_score = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 912 |
alignment_report = ""
|
| 913 |
+
|
| 914 |
+
# PMI-based scoring (requires audio codes and LLM)
|
| 915 |
+
if has_audio_codes:
|
| 916 |
+
if not llm_handler.llm_initialized:
|
| 917 |
+
# Can still try DiT alignment if available
|
| 918 |
+
if not has_dit_alignment_data:
|
| 919 |
+
return t("messages.lm_not_initialized")
|
| 920 |
+
else:
|
| 921 |
+
# Build metadata dictionary from both LM metadata and user inputs
|
| 922 |
+
metadata = {}
|
| 923 |
+
|
| 924 |
+
# Priority 1: Use LM-generated metadata if available
|
| 925 |
+
if lm_metadata and isinstance(lm_metadata, dict):
|
| 926 |
+
metadata.update(lm_metadata)
|
| 927 |
+
|
| 928 |
+
# Priority 2: Add user-provided metadata (if not already in LM metadata)
|
| 929 |
+
if bpm is not None and 'bpm' not in metadata:
|
| 930 |
+
try:
|
| 931 |
+
metadata['bpm'] = int(bpm)
|
| 932 |
+
except:
|
| 933 |
+
pass
|
| 934 |
+
|
| 935 |
+
if caption and 'caption' not in metadata:
|
| 936 |
+
metadata['caption'] = caption
|
| 937 |
+
|
| 938 |
+
if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
|
| 939 |
+
try:
|
| 940 |
+
metadata['duration'] = int(audio_duration)
|
| 941 |
+
except:
|
| 942 |
+
pass
|
| 943 |
+
|
| 944 |
+
if key_scale and key_scale.strip() and 'keyscale' not in metadata:
|
| 945 |
+
metadata['keyscale'] = key_scale.strip()
|
| 946 |
+
|
| 947 |
+
if vocal_language and vocal_language.strip() and 'language' not in metadata:
|
| 948 |
+
metadata['language'] = vocal_language.strip()
|
| 949 |
+
|
| 950 |
+
if time_signature and time_signature.strip() and 'timesignature' not in metadata:
|
| 951 |
+
metadata['timesignature'] = time_signature.strip()
|
| 952 |
+
|
| 953 |
+
# Calculate per-condition scores with appropriate metrics
|
| 954 |
+
# - Metadata fields (bpm, duration, etc.): Top-k recall
|
| 955 |
+
# - Caption and lyrics: PMI (normalized)
|
| 956 |
+
scores_per_condition, global_score, status = calculate_pmi_score_per_condition(
|
| 957 |
+
llm_handler=llm_handler,
|
| 958 |
+
audio_codes=audio_codes_str,
|
| 959 |
+
caption=caption or "",
|
| 960 |
+
lyrics=lyrics or "",
|
| 961 |
+
metadata=metadata if metadata else None,
|
| 962 |
+
temperature=1.0,
|
| 963 |
+
topk=10,
|
| 964 |
+
score_scale=score_scale
|
| 965 |
+
)
|
| 966 |
|
| 967 |
+
# DiT alignment scoring (works even without audio codes - for Cover/Repaint modes)
|
| 968 |
+
if has_dit_alignment_data:
|
| 969 |
try:
|
| 970 |
align_result = dit_handler.get_lyric_score(
|
| 971 |
pred_latent=extra_tensor_data.get('pred_latent'),
|
|
|
|
| 992 |
except Exception as e:
|
| 993 |
alignment_report = f"\n⚠️ Alignment Score Error: {str(e)}"
|
| 994 |
|
| 995 |
+
# Format display string
|
| 996 |
+
if has_audio_codes and llm_handler.llm_initialized:
|
| 997 |
+
# Full scoring with PMI + alignment
|
| 998 |
+
if global_score == 0.0 and not scores_per_condition:
|
| 999 |
+
# PMI scoring failed but we might have alignment
|
| 1000 |
+
if alignment_report and not alignment_report.startswith("\n⚠️"):
|
| 1001 |
+
final_output = "📊 DiT Alignment Scores (LM codes not available):\n"
|
| 1002 |
+
final_output += alignment_report
|
| 1003 |
+
return final_output
|
| 1004 |
+
return t("messages.score_failed", error="PMI scoring returned no results")
|
| 1005 |
+
else:
|
| 1006 |
+
# Build per-condition scores display
|
| 1007 |
+
condition_lines = []
|
| 1008 |
+
for condition_name, score_value in sorted(scores_per_condition.items()):
|
| 1009 |
+
condition_lines.append(
|
| 1010 |
+
f" • {condition_name}: {score_value:.4f}"
|
| 1011 |
+
)
|
| 1012 |
+
|
| 1013 |
+
conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
|
| 1014 |
|
| 1015 |
+
final_output = (
|
| 1016 |
+
f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
|
| 1017 |
+
f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n"
|
| 1018 |
+
)
|
| 1019 |
|
| 1020 |
+
if alignment_report:
|
| 1021 |
+
final_output += alignment_report + "\n"
|
| 1022 |
|
| 1023 |
+
final_output += "Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI"
|
| 1024 |
+
return final_output
|
| 1025 |
+
else:
|
| 1026 |
+
# Only DiT alignment available (Cover/Repaint mode fallback)
|
| 1027 |
+
if alignment_report and not alignment_report.startswith("\n⚠️"):
|
| 1028 |
+
final_output = "📊 DiT Alignment Scores (LM codes not available for Cover/Repaint mode):\n"
|
| 1029 |
+
final_output += alignment_report
|
| 1030 |
+
return final_output
|
| 1031 |
+
elif alignment_report:
|
| 1032 |
+
return alignment_report
|
| 1033 |
+
else:
|
| 1034 |
+
return "⚠️ No scoring data available"
|
| 1035 |
|
| 1036 |
except Exception as e:
|
| 1037 |
import traceback
|