Spaces:
Running
on
A100
Running
on
A100
remove no use import
Browse files- acestep/gradio_ui.py +105 -0
- acestep/handler.py +42 -5
acestep/gradio_ui.py
CHANGED
|
@@ -761,4 +761,109 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
|
|
| 761 |
],
|
| 762 |
outputs=[generation_section["text2music_audio_code_string"]]
|
| 763 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
|
|
|
| 761 |
],
|
| 762 |
outputs=[generation_section["text2music_audio_code_string"]]
|
| 763 |
)
|
| 764 |
+
|
| 765 |
+
# Update instruction and UI visibility based on task type
|
| 766 |
+
def update_instruction_ui(
|
| 767 |
+
task_type_value: str,
|
| 768 |
+
track_name_value: Optional[str],
|
| 769 |
+
complete_track_classes_value: list,
|
| 770 |
+
audio_codes_content: str = ""
|
| 771 |
+
) -> tuple:
|
| 772 |
+
"""Update instruction and UI visibility based on task type."""
|
| 773 |
+
instruction = handler.generate_instruction(
|
| 774 |
+
task_type=task_type_value,
|
| 775 |
+
track_name=track_name_value,
|
| 776 |
+
complete_track_classes=complete_track_classes_value
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
# Show track_name for lego and extract
|
| 780 |
+
track_name_visible = task_type_value in ["lego", "extract"]
|
| 781 |
+
# Show complete_track_classes for complete
|
| 782 |
+
complete_visible = task_type_value == "complete"
|
| 783 |
+
# Show audio_cover_strength for cover
|
| 784 |
+
audio_cover_strength_visible = task_type_value == "cover"
|
| 785 |
+
# Show audio_code_string for cover
|
| 786 |
+
audio_code_visible = task_type_value == "cover"
|
| 787 |
+
# Show repainting controls for repaint and lego
|
| 788 |
+
repainting_visible = task_type_value in ["repaint", "lego"]
|
| 789 |
+
# Show use_5hz_lm, lm_temperature for text2music
|
| 790 |
+
use_5hz_lm_visible = task_type_value == "text2music"
|
| 791 |
+
# Show text2music_audio_codes if task is text2music OR if it has content
|
| 792 |
+
# This allows it to stay visible even if user switches task type but has codes
|
| 793 |
+
has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
|
| 794 |
+
text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
|
| 795 |
+
|
| 796 |
+
return (
|
| 797 |
+
instruction, # instruction_display_gen
|
| 798 |
+
gr.update(visible=track_name_visible), # track_name
|
| 799 |
+
gr.update(visible=complete_visible), # complete_track_classes
|
| 800 |
+
gr.update(visible=audio_cover_strength_visible), # audio_cover_strength
|
| 801 |
+
gr.update(visible=repainting_visible), # repainting_group
|
| 802 |
+
gr.update(visible=audio_code_visible), # audio_code_string
|
| 803 |
+
gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
|
| 804 |
+
gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
# Bind update_instruction_ui to task_type, track_name, and complete_track_classes changes
|
| 808 |
+
generation_section["task_type"].change(
|
| 809 |
+
fn=update_instruction_ui,
|
| 810 |
+
inputs=[
|
| 811 |
+
generation_section["task_type"],
|
| 812 |
+
generation_section["track_name"],
|
| 813 |
+
generation_section["complete_track_classes"],
|
| 814 |
+
generation_section["text2music_audio_code_string"]
|
| 815 |
+
],
|
| 816 |
+
outputs=[
|
| 817 |
+
generation_section["instruction_display_gen"],
|
| 818 |
+
generation_section["track_name"],
|
| 819 |
+
generation_section["complete_track_classes"],
|
| 820 |
+
generation_section["audio_cover_strength"],
|
| 821 |
+
generation_section["repainting_group"],
|
| 822 |
+
generation_section["audio_code_string"],
|
| 823 |
+
generation_section["use_5hz_lm_row"],
|
| 824 |
+
generation_section["text2music_audio_codes_group"],
|
| 825 |
+
]
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
# Also update instruction when track_name changes (for lego/extract tasks)
|
| 829 |
+
generation_section["track_name"].change(
|
| 830 |
+
fn=update_instruction_ui,
|
| 831 |
+
inputs=[
|
| 832 |
+
generation_section["task_type"],
|
| 833 |
+
generation_section["track_name"],
|
| 834 |
+
generation_section["complete_track_classes"],
|
| 835 |
+
generation_section["text2music_audio_code_string"]
|
| 836 |
+
],
|
| 837 |
+
outputs=[
|
| 838 |
+
generation_section["instruction_display_gen"],
|
| 839 |
+
generation_section["track_name"],
|
| 840 |
+
generation_section["complete_track_classes"],
|
| 841 |
+
generation_section["audio_cover_strength"],
|
| 842 |
+
generation_section["repainting_group"],
|
| 843 |
+
generation_section["audio_code_string"],
|
| 844 |
+
generation_section["use_5hz_lm_row"],
|
| 845 |
+
generation_section["text2music_audio_codes_group"],
|
| 846 |
+
]
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
# Also update instruction when complete_track_classes changes (for complete task)
|
| 850 |
+
generation_section["complete_track_classes"].change(
|
| 851 |
+
fn=update_instruction_ui,
|
| 852 |
+
inputs=[
|
| 853 |
+
generation_section["task_type"],
|
| 854 |
+
generation_section["track_name"],
|
| 855 |
+
generation_section["complete_track_classes"],
|
| 856 |
+
generation_section["text2music_audio_code_string"]
|
| 857 |
+
],
|
| 858 |
+
outputs=[
|
| 859 |
+
generation_section["instruction_display_gen"],
|
| 860 |
+
generation_section["track_name"],
|
| 861 |
+
generation_section["complete_track_classes"],
|
| 862 |
+
generation_section["audio_cover_strength"],
|
| 863 |
+
generation_section["repainting_group"],
|
| 864 |
+
generation_section["audio_code_string"],
|
| 865 |
+
generation_section["use_5hz_lm_row"],
|
| 866 |
+
generation_section["text2music_audio_codes_group"],
|
| 867 |
+
]
|
| 868 |
+
)
|
| 869 |
|
acestep/handler.py
CHANGED
|
@@ -4,7 +4,6 @@ Encapsulates all data processing and business logic as a bridge between model an
|
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
import math
|
| 7 |
-
import glob
|
| 8 |
import tempfile
|
| 9 |
import traceback
|
| 10 |
import re
|
|
@@ -12,10 +11,6 @@ import random
|
|
| 12 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 13 |
|
| 14 |
import torch
|
| 15 |
-
import torch.nn.functional as F
|
| 16 |
-
import matplotlib.pyplot as plt
|
| 17 |
-
import numpy as np
|
| 18 |
-
import scipy.io.wavfile as wavfile
|
| 19 |
import torchaudio
|
| 20 |
import soundfile as sf
|
| 21 |
import time
|
|
@@ -666,6 +661,48 @@ class AceStepHandler:
|
|
| 666 |
|
| 667 |
def is_silence(self, audio):
|
| 668 |
return torch.all(audio.abs() < 1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
|
| 670 |
def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 671 |
if audio_file is None:
|
|
|
|
| 4 |
"""
|
| 5 |
import os
|
| 6 |
import math
|
|
|
|
| 7 |
import tempfile
|
| 8 |
import traceback
|
| 9 |
import re
|
|
|
|
| 11 |
from typing import Optional, Dict, Any, Tuple, List, Union
|
| 12 |
|
| 13 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
import torchaudio
|
| 15 |
import soundfile as sf
|
| 16 |
import time
|
|
|
|
| 661 |
|
| 662 |
def is_silence(self, audio):
|
| 663 |
return torch.all(audio.abs() < 1e-6)
|
| 664 |
+
|
| 665 |
+
def generate_instruction(
|
| 666 |
+
self,
|
| 667 |
+
task_type: str,
|
| 668 |
+
track_name: Optional[str] = None,
|
| 669 |
+
complete_track_classes: Optional[List[str]] = None
|
| 670 |
+
) -> str:
|
| 671 |
+
TRACK_NAMES = [
|
| 672 |
+
"woodwinds", "brass", "fx", "synth", "strings", "percussion",
|
| 673 |
+
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
|
| 674 |
+
]
|
| 675 |
+
|
| 676 |
+
if task_type == "text2music":
|
| 677 |
+
return "Fill the audio semantic mask based on the given conditions:"
|
| 678 |
+
elif task_type == "repaint":
|
| 679 |
+
return "Repaint the mask area based on the given conditions:"
|
| 680 |
+
elif task_type == "cover":
|
| 681 |
+
return "Generate audio semantic tokens based on the given conditions:"
|
| 682 |
+
elif task_type == "extract":
|
| 683 |
+
if track_name:
|
| 684 |
+
# Convert to uppercase
|
| 685 |
+
track_name_upper = track_name.upper()
|
| 686 |
+
return f"Extract the {track_name_upper} track from the audio:"
|
| 687 |
+
else:
|
| 688 |
+
return "Extract the track from the audio:"
|
| 689 |
+
elif task_type == "lego":
|
| 690 |
+
if track_name:
|
| 691 |
+
# Convert to uppercase
|
| 692 |
+
track_name_upper = track_name.upper()
|
| 693 |
+
return f"Generate the {track_name_upper} track based on the audio context:"
|
| 694 |
+
else:
|
| 695 |
+
return "Generate the track based on the audio context:"
|
| 696 |
+
elif task_type == "complete":
|
| 697 |
+
if complete_track_classes and len(complete_track_classes) > 0:
|
| 698 |
+
# Convert to uppercase and join with " | "
|
| 699 |
+
track_classes_upper = [t.upper() for t in complete_track_classes]
|
| 700 |
+
complete_track_classes_str = " | ".join(track_classes_upper)
|
| 701 |
+
return f"Complete the input track with {complete_track_classes_str}:"
|
| 702 |
+
else:
|
| 703 |
+
return "Complete the input track:"
|
| 704 |
+
else:
|
| 705 |
+
return "Fill the audio semantic mask based on the given conditions:"
|
| 706 |
|
| 707 |
def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
|
| 708 |
if audio_file is None:
|