ChuxiJ committed on
Commit
09e4f6f
·
1 Parent(s): 4166944

remove unused imports

Browse files
Files changed (2) hide show
  1. acestep/gradio_ui.py +105 -0
  2. acestep/handler.py +42 -5
acestep/gradio_ui.py CHANGED
@@ -761,4 +761,109 @@ def setup_event_handlers(demo, handler, dataset_section, generation_section, res
761
  ],
762
  outputs=[generation_section["text2music_audio_code_string"]]
763
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
 
 
761
  ],
762
  outputs=[generation_section["text2music_audio_code_string"]]
763
  )
764
+
765
+ # Update instruction and UI visibility based on task type
766
+ def update_instruction_ui(
767
+ task_type_value: str,
768
+ track_name_value: Optional[str],
769
+ complete_track_classes_value: list,
770
+ audio_codes_content: str = ""
771
+ ) -> tuple:
772
+ """Update instruction and UI visibility based on task type."""
773
+ instruction = handler.generate_instruction(
774
+ task_type=task_type_value,
775
+ track_name=track_name_value,
776
+ complete_track_classes=complete_track_classes_value
777
+ )
778
+
779
+ # Show track_name for lego and extract
780
+ track_name_visible = task_type_value in ["lego", "extract"]
781
+ # Show complete_track_classes for complete
782
+ complete_visible = task_type_value == "complete"
783
+ # Show audio_cover_strength for cover
784
+ audio_cover_strength_visible = task_type_value == "cover"
785
+ # Show audio_code_string for cover
786
+ audio_code_visible = task_type_value == "cover"
787
+ # Show repainting controls for repaint and lego
788
+ repainting_visible = task_type_value in ["repaint", "lego"]
789
+ # Show use_5hz_lm, lm_temperature for text2music
790
+ use_5hz_lm_visible = task_type_value == "text2music"
791
+ # Show text2music_audio_codes if task is text2music OR if it has content
792
+ # This allows it to stay visible even if user switches task type but has codes
793
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
794
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
795
+
796
+ return (
797
+ instruction, # instruction_display_gen
798
+ gr.update(visible=track_name_visible), # track_name
799
+ gr.update(visible=complete_visible), # complete_track_classes
800
+ gr.update(visible=audio_cover_strength_visible), # audio_cover_strength
801
+ gr.update(visible=repainting_visible), # repainting_group
802
+ gr.update(visible=audio_code_visible), # audio_code_string
803
+ gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
804
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
805
+ )
806
+
807
+ # Bind update_instruction_ui to task_type, track_name, and complete_track_classes changes
808
+ generation_section["task_type"].change(
809
+ fn=update_instruction_ui,
810
+ inputs=[
811
+ generation_section["task_type"],
812
+ generation_section["track_name"],
813
+ generation_section["complete_track_classes"],
814
+ generation_section["text2music_audio_code_string"]
815
+ ],
816
+ outputs=[
817
+ generation_section["instruction_display_gen"],
818
+ generation_section["track_name"],
819
+ generation_section["complete_track_classes"],
820
+ generation_section["audio_cover_strength"],
821
+ generation_section["repainting_group"],
822
+ generation_section["audio_code_string"],
823
+ generation_section["use_5hz_lm_row"],
824
+ generation_section["text2music_audio_codes_group"],
825
+ ]
826
+ )
827
+
828
+ # Also update instruction when track_name changes (for lego/extract tasks)
829
+ generation_section["track_name"].change(
830
+ fn=update_instruction_ui,
831
+ inputs=[
832
+ generation_section["task_type"],
833
+ generation_section["track_name"],
834
+ generation_section["complete_track_classes"],
835
+ generation_section["text2music_audio_code_string"]
836
+ ],
837
+ outputs=[
838
+ generation_section["instruction_display_gen"],
839
+ generation_section["track_name"],
840
+ generation_section["complete_track_classes"],
841
+ generation_section["audio_cover_strength"],
842
+ generation_section["repainting_group"],
843
+ generation_section["audio_code_string"],
844
+ generation_section["use_5hz_lm_row"],
845
+ generation_section["text2music_audio_codes_group"],
846
+ ]
847
+ )
848
+
849
+ # Also update instruction when complete_track_classes changes (for complete task)
850
+ generation_section["complete_track_classes"].change(
851
+ fn=update_instruction_ui,
852
+ inputs=[
853
+ generation_section["task_type"],
854
+ generation_section["track_name"],
855
+ generation_section["complete_track_classes"],
856
+ generation_section["text2music_audio_code_string"]
857
+ ],
858
+ outputs=[
859
+ generation_section["instruction_display_gen"],
860
+ generation_section["track_name"],
861
+ generation_section["complete_track_classes"],
862
+ generation_section["audio_cover_strength"],
863
+ generation_section["repainting_group"],
864
+ generation_section["audio_code_string"],
865
+ generation_section["use_5hz_lm_row"],
866
+ generation_section["text2music_audio_codes_group"],
867
+ ]
868
+ )
869
 
acestep/handler.py CHANGED
@@ -4,7 +4,6 @@ Encapsulates all data processing and business logic as a bridge between model an
4
  """
5
  import os
6
  import math
7
- import glob
8
  import tempfile
9
  import traceback
10
  import re
@@ -12,10 +11,6 @@ import random
12
  from typing import Optional, Dict, Any, Tuple, List, Union
13
 
14
  import torch
15
- import torch.nn.functional as F
16
- import matplotlib.pyplot as plt
17
- import numpy as np
18
- import scipy.io.wavfile as wavfile
19
  import torchaudio
20
  import soundfile as sf
21
  import time
@@ -666,6 +661,48 @@ class AceStepHandler:
666
 
667
    def is_silence(self, audio):
        # Treat the clip as silent when every sample's magnitude is below 1e-6.
        # Returns a 0-dim torch.BoolTensor (truthy in `if` checks).
        return torch.all(audio.abs() < 1e-6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669
 
670
  def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
671
  if audio_file is None:
 
4
  """
5
  import os
6
  import math
 
7
  import tempfile
8
  import traceback
9
  import re
 
11
  from typing import Optional, Dict, Any, Tuple, List, Union
12
 
13
  import torch
 
 
 
 
14
  import torchaudio
15
  import soundfile as sf
16
  import time
 
661
 
662
  def is_silence(self, audio):
663
  return torch.all(audio.abs() < 1e-6)
664
+
665
+ def generate_instruction(
666
+ self,
667
+ task_type: str,
668
+ track_name: Optional[str] = None,
669
+ complete_track_classes: Optional[List[str]] = None
670
+ ) -> str:
671
+ TRACK_NAMES = [
672
+ "woodwinds", "brass", "fx", "synth", "strings", "percussion",
673
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
674
+ ]
675
+
676
+ if task_type == "text2music":
677
+ return "Fill the audio semantic mask based on the given conditions:"
678
+ elif task_type == "repaint":
679
+ return "Repaint the mask area based on the given conditions:"
680
+ elif task_type == "cover":
681
+ return "Generate audio semantic tokens based on the given conditions:"
682
+ elif task_type == "extract":
683
+ if track_name:
684
+ # Convert to uppercase
685
+ track_name_upper = track_name.upper()
686
+ return f"Extract the {track_name_upper} track from the audio:"
687
+ else:
688
+ return "Extract the track from the audio:"
689
+ elif task_type == "lego":
690
+ if track_name:
691
+ # Convert to uppercase
692
+ track_name_upper = track_name.upper()
693
+ return f"Generate the {track_name_upper} track based on the audio context:"
694
+ else:
695
+ return "Generate the track based on the audio context:"
696
+ elif task_type == "complete":
697
+ if complete_track_classes and len(complete_track_classes) > 0:
698
+ # Convert to uppercase and join with " | "
699
+ track_classes_upper = [t.upper() for t in complete_track_classes]
700
+ complete_track_classes_str = " | ".join(track_classes_upper)
701
+ return f"Complete the input track with {complete_track_classes_str}:"
702
+ else:
703
+ return "Complete the input track:"
704
+ else:
705
+ return "Fill the audio semantic mask based on the given conditions:"
706
 
707
  def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]:
708
  if audio_file is None: