"""Gradio app: AI DJ transition generator.

Uploads two songs, runs the repaint-based transition pipeline (optionally with
an ACE-Step LoRA adapter), and presents baseline vs. LoRA artifacts side by side.
"""

import logging
import os
import subprocess
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import spaces
from huggingface_hub import hf_hub_download

from pipeline.transition_generator import (
    PLUGIN_PRESETS,
    TransitionRequest,
    generate_transition_artifacts,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
LOGGER = logging.getLogger(__name__)

# Dropdown labels shown in the UI; "None" maps to no LoRA (absent from LORA_REPO_MAP).
LORA_DROPDOWN_CHOICES = [
    "None",
    "Chinese New Year (official)",
    "Our Trained Guitar-Style LoRA",
]
# UI label -> Hugging Face repo id for the ACE-Step LoRA adapter.
LORA_REPO_MAP = {
    "Chinese New Year (official)": "ACE-Step/ACE-Step-v1.5-chinese-new-year-LoRA",
    "Our Trained Guitar-Style LoRA": "yng314/audio_generation_lora",
}

APP_CSS = """
.adv-item label,
.adv-item .gr-block-label,
.adv-item .gr-block-title {
  white-space: nowrap !important;
  overflow: hidden !important;
  text-overflow: ellipsis !important;
}
.result-audio-label label,
.result-audio-label .gr-block-label,
.result-audio-label .gr-block-title {
  white-space: pre-line !important;
}
.hero-generate-text {
  color: #16a34a !important;
  font-weight: 600;
}
#run-transition-btn,
#run-transition-btn button {
  background: #16a34a !important;
  background-image: none !important;
  border-color: #16a34a !important;
  color: #ffffff !important;
}
#run-transition-btn:hover,
#run-transition-btn button:hover {
  background: #15803d !important;
  background-image: none !important;
  border-color: #15803d !important;
}
"""

APP_THEME = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
    radius_size="lg",
).set(
    block_radius="*radius_xl",
    input_radius="*radius_xl",
    button_large_radius="*radius_xl",
    button_medium_radius="*radius_xl",
    button_small_radius="*radius_xl",
)

# NOTE(review): this head snippet appears empty in this copy of the file — its
# markup may have been stripped during extraction; confirm against the original.
FORCE_DARK_HEAD = """ """

# Default demo inputs (a private HF dataset repo); each falls back to a sane
# default when the env var is unset OR set to an empty string.
DEFAULT_DEMO_REPO = (
    os.getenv("AI_DJ_DEFAULT_DEMO_REPO", "yng314/audio-demo-private").strip()
    or "yng314/audio-demo-private"
)
DEFAULT_DEMO_SONG_A = os.getenv("AI_DJ_DEFAULT_DEMO_SONG_A", "song_a.mp3").strip() or "song_a.mp3"
DEFAULT_DEMO_SONG_B = os.getenv("AI_DJ_DEFAULT_DEMO_SONG_B", "song_b.mp3").strip() or "song_b.mp3"


def _env_flag(name: str, default: bool) -> bool:
    """Read env var *name* as a boolean flag; anything but 0/false/no/off is truthy."""
    raw = os.getenv(name, "1" if default else "0").strip().lower()
    return raw not in {"0", "false", "no", "off"}


def _prefetch_demucs_weights() -> None:
    """Pre-download the Demucs checkpoint during startup.

    Avoids a first-request timeout on ZeroGPU. Best-effort: any failure is
    logged and ignored (the pipeline will download on demand instead).
    """
    if not _env_flag("AI_DJ_PREFETCH_DEMUCS", True):
        return
    model_name = os.getenv("AI_DJ_DEMUCS_MODEL", "htdemucs").strip() or "htdemucs"
    try:
        # Imported lazily: demucs is heavy and only needed when prefetching.
        from demucs.pretrained import get_model  # type: ignore

        LOGGER.info("Prefetching Demucs model '%s'...", model_name)
        get_model(model_name)
        LOGGER.info("Demucs model '%s' prefetch complete.", model_name)
    except Exception as exc:
        LOGGER.warning("Demucs prefetch skipped/failed (%s).", exc)


def _to_optional_float(value) -> Optional[float]:
    """Coerce a UI value to float, treating None/blank/unparsable input as None."""
    if value is None:
        return None
    if isinstance(value, str) and not value.strip():
        return None
    try:
        return float(value)
    except Exception:
        return None


def _normalize_upload_for_ui(path: Optional[str]) -> Optional[str]:
    """Transcode an uploaded audio file to stereo 44.1 kHz 16-bit PCM WAV.

    Returns the normalized path, or the original path when the input is
    missing or ffmpeg fails (best-effort normalization).
    """
    if not path:
        return path
    src = str(path)
    if not os.path.isfile(src):
        return path
    out_dir = os.path.join("outputs", "normalized_uploads")
    os.makedirs(out_dir, exist_ok=True)
    stem = Path(src).stem
    dst = os.path.join(out_dir, f"{stem}_ui_norm.wav")
    cmd = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel",
        "error",
        "-nostdin",
        "-y",
        "-i",
        src,
        "-vn",
        "-ac",
        "2",
        "-ar",
        "44100",
        "-c:a",
        "pcm_s16le",
        dst,
    ]
    try:
        subprocess.run(cmd, check=True)
        return dst
    except Exception as exc:
        LOGGER.warning("Upload normalization failed for %s (%s). Using original file.", src, exc)
        return src


def _download_default_demo_song(repo_id: str, filename: str, token: Optional[str]) -> Optional[str]:
    """Download one demo song from a HF dataset repo and normalize it for the UI.

    Returns the local (normalized) path, or None on any failure.
    """
    if not repo_id or not filename:
        return None
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=filename,
            token=token,
            local_dir="outputs/default_inputs",
        )
        return _normalize_upload_for_ui(local_path)
    except Exception as exc:
        LOGGER.warning("Default demo song download failed for %s/%s (%s).", repo_id, filename, exc)
        return None


def _resolve_default_demo_inputs() -> Tuple[Optional[str], Optional[str], str]:
    """Resolve default Song A/B paths plus a status message for the UI.

    Returns (song_a_path, song_b_path, status_markdown); paths are None when
    the demo songs are disabled or could not be fetched.
    """
    if not _env_flag("AI_DJ_ENABLE_DEFAULT_DEMO", True):
        return None, None, "Default demo songs disabled (AI_DJ_ENABLE_DEFAULT_DEMO=0)."
    token = os.getenv("HF_TOKEN", "").strip() or None
    if token is None:
        return None, None, "Default demo songs not loaded: missing HF_TOKEN secret."
    song_a_default = _download_default_demo_song(DEFAULT_DEMO_REPO, DEFAULT_DEMO_SONG_A, token)
    song_b_default = _download_default_demo_song(DEFAULT_DEMO_REPO, DEFAULT_DEMO_SONG_B, token)
    if song_a_default and song_b_default:
        return song_a_default, song_b_default, (
            f"Default demo songs loaded from `{DEFAULT_DEMO_REPO}` "
            f"(`{DEFAULT_DEMO_SONG_A}`, `{DEFAULT_DEMO_SONG_B}`)."
        )
    return None, None, (
        f"Default demo songs not loaded from `{DEFAULT_DEMO_REPO}`; "
        "please upload Song A and Song B manually."
    )


def _build_transition_request(
    song_a,
    song_b,
    plugin_id,
    instruction_text,
    transition_bars,
    pre_context_sec,
    post_context_sec,
    analysis_sec,
    bpm_target,
    cue_a_sec,
    cue_b_sec,
    creativity_strength,
    inference_steps,
    seed,
    lora_path: str,
    lora_scale,
    output_dir: str,
) -> TransitionRequest:
    """Build a TransitionRequest; baseline and LoRA runs differ only in
    *lora_path* and *output_dir*."""
    return TransitionRequest(
        song_a_path=song_a,
        song_b_path=song_b,
        plugin_id=plugin_id,
        instruction_text=instruction_text or "",
        transition_base_mode="B-base-fixed",
        transition_bars=int(transition_bars),
        pre_context_sec=float(pre_context_sec),
        repaint_width_sec=4.0,
        post_context_sec=float(post_context_sec),
        analysis_sec=float(analysis_sec),
        bpm_target=_to_optional_float(bpm_target),
        cue_a_sec=_to_optional_float(cue_a_sec),
        cue_b_sec=_to_optional_float(cue_b_sec),
        creativity_strength=float(creativity_strength),
        inference_steps=int(inference_steps),
        seed=int(seed),
        acestep_lora_path=lora_path,
        acestep_lora_scale=float(lora_scale),
        output_dir=output_dir,
    )


@spaces.GPU(duration=120)
def _run_transition(
    song_a,
    song_b,
    plugin_id,
    instruction_text,
    transition_bars,
    pre_context_sec,
    post_context_sec,
    analysis_sec,
    bpm_target,
    creativity_strength,
    inference_steps,
    seed,
    cue_a_sec,
    cue_b_sec,
    lora_choice,
    lora_scale,
    output_dir,
):
    """Generate baseline (no-LoRA) transition artifacts, plus LoRA variants
    when a LoRA adapter is selected.

    Returns an 8-tuple of file paths: (transition, hard_splice, rough_stitched,
    stitched) for the baseline, then the same four for the LoRA run (None each
    when no LoRA is selected).

    Raises gr.Error on missing inputs or pipeline failure.
    """
    if not song_a or not song_b:
        raise gr.Error("Please upload both Song A and Song B.")

    # "None" (or any unknown label) maps to no LoRA.
    selected_lora_path = LORA_REPO_MAP.get(str(lora_choice), "")
    output_root = (output_dir or "outputs").strip()
    base_output_dir = os.path.join(output_root, "compare_no_lora")
    lora_output_dir = os.path.join(output_root, "compare_lora")

    common_args = (
        song_a,
        song_b,
        plugin_id,
        instruction_text,
        transition_bars,
        pre_context_sec,
        post_context_sec,
        analysis_sec,
        bpm_target,
        cue_a_sec,
        cue_b_sec,
        creativity_strength,
        inference_steps,
        seed,
    )

    base_request = _build_transition_request(
        *common_args, lora_path="", lora_scale=lora_scale, output_dir=base_output_dir
    )
    try:
        baseline = generate_transition_artifacts(base_request)
    except Exception as exc:
        raise gr.Error(str(exc)) from exc

    lora_transition = None
    lora_hard_splice = None
    lora_rough_stitched = None
    lora_stitched = None
    if selected_lora_path:
        lora_request = _build_transition_request(
            *common_args,
            lora_path=selected_lora_path,
            lora_scale=lora_scale,
            output_dir=lora_output_dir,
        )
        try:
            lora_result = generate_transition_artifacts(lora_request)
            lora_transition = lora_result.transition_path
            lora_hard_splice = lora_result.hard_splice_path
            lora_rough_stitched = lora_result.rough_stitched_path
            lora_stitched = lora_result.stitched_path
        except Exception as exc:
            raise gr.Error(f"Baseline generated, but LoRA variant failed: {exc}") from exc

    return (
        baseline.transition_path,
        baseline.hard_splice_path,
        baseline.rough_stitched_path,
        baseline.stitched_path,
        lora_transition,
        lora_hard_splice,
        lora_rough_stitched,
        lora_stitched,
    )


def build_ui() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI for the transition generator."""
    default_song_a, default_song_b, default_demo_status = _resolve_default_demo_inputs()

    with gr.Blocks(theme=APP_THEME, css=APP_CSS) as demo:
        # NOTE(review): the hero markup appears to have been stripped in this
        # copy — restore the original tags/classes (e.g. .hero-generate-text)
        # if they were lost during extraction.
        gr.HTML(
            """
AI DJ Transition Generator

Upload two songs and generate a smooth transition between them. For best results, please use default demo songs and parameters (just simply click the button "Generate transition artifacts").
            """.strip()
        )
        with gr.Row():
            gr.Markdown(
                """
### How to use
1. Upload **Song A** (current track) and **Song B** (next track). For demonstration, there are two default songs.
2. Choose a **Transition style plugin**, this will control the style of the transition.
3. Optionally add **Text instruction** (e.g., smooth, rising energy, no vocals).
4. Select **LoRA adapter**, this will control the style of the transition. For demonstration, there is one default LoRA adapter "Our Trained Guitar-Style LoRA", which is trained on guitar-style music by ourselves.
5. Click **Generate transition artifacts**.
                """.strip(),
                container=False,
                elem_classes=["plain-info"],
            )
            gr.Markdown(
                """
### Outputs (If LoRA is selected, there will be results in the LoRA Variant section)
- **Generated transition clip**: AI-generated repaint transition segment.
- **Hard splice baseline (no transition)**: direct cut baseline.
- **No-repaint rough stitch**: stitched baseline without repaint.
- **Final stitched clip**: final result with transition inserted.
                """.strip(),
                container=False,
                elem_classes=["plain-info"],
            )
        gr.Markdown(default_demo_status, elem_classes=["plain-info"])

        with gr.Row():
            song_a = gr.Audio(
                label="Song A (mix out)",
                type="filepath",
                sources=["upload"],
                value=default_song_a,
            )
            song_b = gr.Audio(
                label="Song B (mix in)",
                type="filepath",
                sources=["upload"],
                value=default_song_b,
            )
        # Normalize fresh uploads in-place so downstream analysis sees uniform WAVs.
        song_a.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_a,
            outputs=song_a,
            queue=False,
        )
        song_b.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_b,
            outputs=song_b,
            queue=False,
        )

        with gr.Row():
            with gr.Column():
                plugin_id = gr.Dropdown(
                    label="Transition style plugin",
                    choices=list(PLUGIN_PRESETS.keys()),
                    value="Smooth Blend",
                    info="Select the transition style profile used to guide repaint generation.",
                )
            with gr.Column():
                lora_choice = gr.Dropdown(
                    label="LoRA adapter",
                    choices=LORA_DROPDOWN_CHOICES,
                    value="Our Trained Guitar-Style LoRA",
                    info="Select an ACE-Step LoRA adapter to apply during repaint.",
                )
                lora_scale = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=1.2,
                    step=0.05,
                    label="LoRA scale",
                )
            with gr.Column():
                instruction_text = gr.Textbox(
                    label="Text instruction",
                    placeholder="e.g., smooth, rising energy, no vocals",
                    lines=2,
                    info="Optional extra prompt to refine transition mood, texture, and arrangement.",
                )

        with gr.Accordion("Advanced controls", open=False):
            with gr.Row():
                transition_bars = gr.Dropdown(
                    label="Transition period length (bars)",
                    choices=[4, 8, 16],
                    value=8,
                    info="Controls transition duration. Pipeline uses fixed B-base strategy with A as reference.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                pre_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds before seam (Song A context)",
                    info="How much Song A context is included before the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                post_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds after seam (Song B context)",
                    info="How much Song B context is included after the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                analysis_sec = gr.Slider(
                    minimum=10,
                    maximum=90,
                    value=90,
                    step=5,
                    label="Analysis window (seconds)",
                    info="Length of each track window used for BPM/cue analysis and alignment.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                bpm_target = gr.Number(
                    label="Optional BPM target override",
                    value=None,
                    info="Force Song A reference BPM for alignment when auto BPM is not desired.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                creativity_strength = gr.Slider(
                    minimum=1.0,
                    maximum=12.0,
                    value=12.0,
                    step=0.5,
                    label="Creativity strength (guidance)",
                    info="Higher values push stronger prompt/style guidance in repaint generation.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                inference_steps = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=64,
                    step=1,
                    label="ACE-Step inference steps",
                    info="More steps usually improve detail/stability but increase runtime.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                    info="Random seed for reproducibility; use the same value to repeat a run.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                cue_a_sec = gr.Textbox(
                    label="Optional cue A override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song A cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                cue_b_sec = gr.Textbox(
                    label="Optional cue B override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song B cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                output_dir = gr.Textbox(
                    label="Output directory",
                    value="outputs",
                    info="Folder where generated transition artifacts will be saved.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )

        run_btn = gr.Button(
            "Generate transition artifacts", variant="primary", elem_id="run-transition-btn"
        )

        gr.Markdown("### Baseline (No LoRA)")
        with gr.Row():
            transition_audio = gr.Audio(
                label="Generated transition clip\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            hard_splice_audio = gr.Audio(
                label="Hard splice baseline\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            rough_stitched_audio = gr.Audio(
                label="No-repaint rough stitch\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            stitched_audio = gr.Audio(
                label="Final stitched clip\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )

        gr.Markdown("### LoRA Variant (generated only when LoRA adapter is selected)")
        with gr.Row():
            lora_transition_audio = gr.Audio(
                label="Generated transition clip\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_hard_splice_audio = gr.Audio(
                label="Hard splice baseline\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_rough_stitched_audio = gr.Audio(
                label="No-repaint rough stitch\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_stitched_audio = gr.Audio(
                label="Final stitched clip\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )

        run_btn.click(
            fn=_run_transition,
            inputs=[
                song_a,
                song_b,
                plugin_id,
                instruction_text,
                transition_bars,
                pre_context_sec,
                post_context_sec,
                analysis_sec,
                bpm_target,
                creativity_strength,
                inference_steps,
                seed,
                cue_a_sec,
                cue_b_sec,
                lora_choice,
                lora_scale,
                output_dir,
            ],
            outputs=[
                transition_audio,
                hard_splice_audio,
                rough_stitched_audio,
                stitched_audio,
                lora_transition_audio,
                lora_hard_splice_audio,
                lora_rough_stitched_audio,
                lora_stitched_audio,
            ],
        )
    return demo


# Warm the Demucs cache at import time so the first GPU request doesn't time out.
_prefetch_demucs_weights()
demo = build_ui()

if __name__ == "__main__":
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        head=FORCE_DARK_HEAD,
        footer_links=["api", "gradio"],
    )