Spaces:
Running on Zero
Running on Zero
| import logging | |
| import os | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import hf_hub_download | |
| from pipeline.transition_generator import ( | |
| PLUGIN_PRESETS, | |
| TransitionRequest, | |
| generate_transition_artifacts, | |
| ) | |
# Configure logging once at import time: INFO level, with each record rendered
# as "timestamp | level | logger name | message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Module-level logger used by the helper functions in this file.
LOGGER = logging.getLogger(__name__)
# Display name -> Hugging Face repo id for every selectable ACE-Step LoRA adapter.
LORA_REPO_MAP = {
    "Chinese New Year (official)": "ACE-Step/ACE-Step-v1.5-chinese-new-year-LoRA",
    "Our Trained Guitar-Style LoRA": "yng314/audio_generation_lora",
}

# Entries offered in the LoRA dropdown: the "None" sentinel (which has no entry
# in LORA_REPO_MAP) followed by every adapter name in declaration order.
LORA_DROPDOWN_CHOICES = ["None", *LORA_REPO_MAP]
# Custom stylesheet handed to gr.Blocks(css=...):
#   * .adv-item           - keeps labels of the advanced controls on one line,
#                           truncating with an ellipsis.
#   * .result-audio-label - "pre-line" so the "\n" embedded in the result
#                           player labels renders as a line break.
#   * .hero-generate-text - green highlight for the button name in the intro text.
#   * #run-transition-btn - solid green generate button with a darker hover state.
APP_CSS = """
.adv-item label,
.adv-item .gr-block-label,
.adv-item .gr-block-title {
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
}
.result-audio-label label,
.result-audio-label .gr-block-label,
.result-audio-label .gr-block-title {
white-space: pre-line !important;
}
.hero-generate-text {
color: #16a34a !important;
font-weight: 600;
}
#run-transition-btn,
#run-transition-btn button {
background: #16a34a !important;
background-image: none !important;
border-color: #16a34a !important;
color: #ffffff !important;
}
#run-transition-btn:hover,
#run-transition-btn button:hover {
background: #15803d !important;
background-image: none !important;
border-color: #15803d !important;
}
"""
# Gradio "Soft" theme with a blue primary hue and slate neutrals; the .set()
# overrides give blocks, inputs and every button size the extra-large radius.
APP_THEME = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
    radius_size="lg",
).set(
    block_radius="*radius_xl",
    input_radius="*radius_xl",
    button_large_radius="*radius_xl",
    button_medium_radius="*radius_xl",
    button_small_radius="*radius_xl",
)
# HTML snippet passed as `head=` to demo.launch(). The inline script reloads the
# page with "?__theme=dark" when that query parameter is not already "dark",
# and otherwise adds the "dark" class to the root element. Any exception is
# swallowed so the page still loads without the forced theme.
FORCE_DARK_HEAD = """
<script>
(() => {
try {
const url = new URL(window.location.href);
if (url.searchParams.get("__theme") !== "dark") {
url.searchParams.set("__theme", "dark");
window.location.replace(url.toString());
return;
}
// Ensure dark class is present as early as possible.
document.documentElement.classList.add("dark");
} catch (err) {
// No-op: fail open if URL manipulation is unavailable.
}
})();
</script>
"""
def _stripped_env(name: str, fallback: str) -> str:
    """Return environment variable *name* without surrounding whitespace.

    *fallback* is used only when the variable is not set at all.
    """
    return os.environ.get(name, fallback).strip()


# Hub repository that hosts the default demo songs. A blank override is kept
# as an empty string here (no fallback is applied to the repo id).
DEFAULT_DEMO_REPO = _stripped_env("AI_DJ_DEFAULT_DEMO_REPO", "yng314/audio-demo-private")
# File names inside that repository; a blank override falls back to the
# built-in name.
DEFAULT_DEMO_SONG_A = _stripped_env("AI_DJ_DEFAULT_DEMO_SONG_A", "song_a.mp3") or "song_a.mp3"
DEFAULT_DEMO_SONG_B = _stripped_env("AI_DJ_DEFAULT_DEMO_SONG_B", "song_b.mp3") or "song_b.mp3"
| def _env_flag(name: str, default: bool) -> bool: | |
| raw = os.getenv(name, "1" if default else "0").strip().lower() | |
| return raw not in {"0", "false", "no", "off"} | |
def _prefetch_demucs_weights() -> None:
    """Best-effort warm-up that loads the Demucs model once at startup.

    Controlled by two environment variables: AI_DJ_PREFETCH_DEMUCS (switch,
    on by default) and AI_DJ_DEMUCS_MODEL (model name, "htdemucs" when unset
    or blank). Any failure, including Demucs not being installed, is logged
    as a warning and otherwise ignored.
    """
    # Pre-download Demucs checkpoint during startup to avoid first-request timeout on ZeroGPU.
    if _env_flag("AI_DJ_PREFETCH_DEMUCS", True):
        requested = os.getenv("AI_DJ_DEMUCS_MODEL", "htdemucs").strip()
        model_name = requested if requested else "htdemucs"
        try:
            # Imported lazily so a missing Demucs install only skips the prefetch.
            from demucs.pretrained import get_model  # type: ignore

            LOGGER.info("Prefetching Demucs model '%s'...", model_name)
            get_model(model_name)
            LOGGER.info("Demucs model '%s' prefetch complete.", model_name)
        except Exception as exc:
            LOGGER.warning("Demucs prefetch skipped/failed (%s).", exc)
| def _to_optional_float(value) -> Optional[float]: | |
| if value is None: | |
| return None | |
| if isinstance(value, str) and not value.strip(): | |
| return None | |
| try: | |
| return float(value) | |
| except Exception: | |
| return None | |
| def _normalize_upload_for_ui(path: Optional[str]) -> Optional[str]: | |
| if not path: | |
| return path | |
| src = str(path) | |
| if not os.path.isfile(src): | |
| return path | |
| out_dir = os.path.join("outputs", "normalized_uploads") | |
| os.makedirs(out_dir, exist_ok=True) | |
| stem = Path(src).stem | |
| dst = os.path.join(out_dir, f"{stem}_ui_norm.wav") | |
| cmd = [ | |
| "ffmpeg", | |
| "-hide_banner", | |
| "-loglevel", | |
| "error", | |
| "-nostdin", | |
| "-y", | |
| "-i", | |
| src, | |
| "-vn", | |
| "-ac", | |
| "2", | |
| "-ar", | |
| "44100", | |
| "-c:a", | |
| "pcm_s16le", | |
| dst, | |
| ] | |
| try: | |
| subprocess.run(cmd, check=True) | |
| return dst | |
| except Exception as exc: | |
| LOGGER.warning("Upload normalization failed for %s (%s). Using original file.", src, exc) | |
| return src | |
| def _download_default_demo_song(repo_id: str, filename: str, token: Optional[str]) -> Optional[str]: | |
| if not repo_id or not filename: | |
| return None | |
| try: | |
| local_path = hf_hub_download( | |
| repo_id=repo_id, | |
| repo_type="dataset", | |
| filename=filename, | |
| token=token, | |
| local_dir="outputs/default_inputs", | |
| ) | |
| return _normalize_upload_for_ui(local_path) | |
| except Exception as exc: | |
| LOGGER.warning("Default demo song download failed for %s/%s (%s).", repo_id, filename, exc) | |
| return None | |
def _resolve_default_demo_inputs() -> Tuple[Optional[str], Optional[str], str]:
    """Fetch the two default demo songs and describe the outcome.

    Returns:
        ``(song_a_path, song_b_path, status_markdown)``. Both paths are None
        unless both downloads succeeded; the status text says whether the
        demo songs were loaded, disabled via AI_DJ_ENABLE_DEFAULT_DEMO,
        skipped because HF_TOKEN is missing, or failed to download.
    """
    demo_enabled = _env_flag("AI_DJ_ENABLE_DEFAULT_DEMO", True)
    if not demo_enabled:
        return None, None, "Default demo songs disabled (AI_DJ_ENABLE_DEFAULT_DEMO=0)."

    token = os.getenv("HF_TOKEN", "").strip()
    if not token:
        return None, None, "Default demo songs not loaded: missing HF_TOKEN secret."

    # Song A is requested first, then Song B.
    path_a, path_b = (
        _download_default_demo_song(DEFAULT_DEMO_REPO, name, token)
        for name in (DEFAULT_DEMO_SONG_A, DEFAULT_DEMO_SONG_B)
    )

    if not (path_a and path_b):
        # One usable song is not enough for a transition: report failure.
        failure_note = (
            f"Default demo songs not loaded from `{DEFAULT_DEMO_REPO}`; "
            "please upload Song A and Song B manually."
        )
        return None, None, failure_note

    success_note = (
        f"Default demo songs loaded from `{DEFAULT_DEMO_REPO}` "
        f"(`{DEFAULT_DEMO_SONG_A}`, `{DEFAULT_DEMO_SONG_B}`)."
    )
    return path_a, path_b, success_note
def _run_transition(
    song_a,
    song_b,
    plugin_id,
    instruction_text,
    transition_bars,
    pre_context_sec,
    post_context_sec,
    analysis_sec,
    bpm_target,
    creativity_strength,
    inference_steps,
    seed,
    cue_a_sec,
    cue_b_sec,
    lora_choice,
    lora_scale,
    output_dir,
):
    """Generate the baseline transition and, if requested, a LoRA variant.

    The baseline run never applies a LoRA adapter and writes below
    ``<output_dir>/compare_no_lora``. When *lora_choice* names an entry of
    ``LORA_REPO_MAP``, a second run with otherwise identical settings applies
    that adapter and writes below ``<output_dir>/compare_lora``.

    Args:
        song_a, song_b: File paths of the two tracks (both required).
        plugin_id: Transition style plugin name.
        instruction_text: Optional free-text prompt; None becomes "".
        transition_bars, inference_steps, seed: Converted with ``int``.
        pre_context_sec, post_context_sec, analysis_sec, creativity_strength,
            lora_scale: Converted with ``float``.
        bpm_target, cue_a_sec, cue_b_sec: Optional overrides; blank or
            unparseable values become None.
        lora_choice: Dropdown label; labels missing from ``LORA_REPO_MAP``
            (such as "None") disable the LoRA run.
        output_dir: Root output folder; blank falls back to "outputs".

    Returns:
        An 8-tuple: transition, hard-splice, rough-stitched and stitched
        paths of the baseline, followed by the same four for the LoRA run
        (all None when no LoRA run was made).

    Raises:
        gr.Error: When a song is missing, a numeric setting cannot be
            converted, or either generation run fails.
    """
    if not song_a or not song_b:
        raise gr.Error("Please upload both Song A and Song B.")

    selected_lora_path = LORA_REPO_MAP.get(str(lora_choice), "")
    # Strip first, then fall back: a whitespace-only entry used to yield an
    # empty root instead of the default "outputs" folder.
    output_root = (output_dir or "").strip() or "outputs"

    # Settings common to both runs, converted once so the two requests cannot
    # drift apart. A cleared numeric field reaches us as None; report that as
    # a readable UI error rather than a raw TypeError.
    try:
        shared_fields = {
            "song_a_path": song_a,
            "song_b_path": song_b,
            "plugin_id": plugin_id,
            "instruction_text": instruction_text or "",
            "transition_base_mode": "B-base-fixed",
            "transition_bars": int(transition_bars),
            "pre_context_sec": float(pre_context_sec),
            "repaint_width_sec": 4.0,
            "post_context_sec": float(post_context_sec),
            "analysis_sec": float(analysis_sec),
            "bpm_target": _to_optional_float(bpm_target),
            "cue_a_sec": _to_optional_float(cue_a_sec),
            "cue_b_sec": _to_optional_float(cue_b_sec),
            "creativity_strength": float(creativity_strength),
            "inference_steps": int(inference_steps),
            "seed": int(seed),
            "acestep_lora_scale": float(lora_scale),
        }
    except (TypeError, ValueError) as exc:
        raise gr.Error(f"Invalid numeric setting: {exc}") from exc

    base_request = TransitionRequest(
        acestep_lora_path="",
        output_dir=os.path.join(output_root, "compare_no_lora"),
        **shared_fields,
    )
    try:
        baseline = generate_transition_artifacts(base_request)
    except Exception as exc:
        LOGGER.exception("Baseline transition generation failed.")
        raise gr.Error(str(exc)) from exc

    lora_transition = None
    lora_hard_splice = None
    lora_rough_stitched = None
    lora_stitched = None
    if selected_lora_path:
        lora_request = TransitionRequest(
            acestep_lora_path=selected_lora_path,
            output_dir=os.path.join(output_root, "compare_lora"),
            **shared_fields,
        )
        try:
            lora_result = generate_transition_artifacts(lora_request)
            lora_transition = lora_result.transition_path
            lora_hard_splice = lora_result.hard_splice_path
            lora_rough_stitched = lora_result.rough_stitched_path
            lora_stitched = lora_result.stitched_path
        except Exception as exc:
            LOGGER.exception("LoRA transition generation failed.")
            raise gr.Error(f"Baseline generated, but LoRA variant failed: {exc}") from exc

    return (
        baseline.transition_path,
        baseline.hard_splice_path,
        baseline.rough_stitched_path,
        baseline.stitched_path,
        lora_transition,
        lora_hard_splice,
        lora_rough_stitched,
        lora_stitched,
    )
def _result_audio_row(variant: str) -> list:
    """Create one row with the four result players for *variant*.

    Returns the players in the order transition clip, hard splice, rough
    stitch, final stitch. Must be called inside an open ``gr.Blocks`` context.
    """
    titles = (
        "Generated transition clip",
        "Hard splice baseline",
        "No-repaint rough stitch",
        "Final stitched clip",
    )
    with gr.Row():
        return [
            gr.Audio(
                # The "\n" puts the variant on its own line (see .result-audio-label CSS).
                label=f"{title}\n({variant})",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            for title in titles
        ]


def build_ui() -> gr.Blocks:
    """Assemble the Gradio interface and wire up its event handlers.

    Resolves the default demo songs first (which may download them), then
    lays out the intro text, the two song inputs, the style / LoRA / prompt
    controls, the collapsed advanced controls, the generate button and the
    eight result players, and finally connects the button to
    `_run_transition`.

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """
    default_song_a, default_song_b, default_demo_status = _resolve_default_demo_inputs()

    # Only preselect "Smooth Blend" when the pipeline really offers it;
    # otherwise fall back to the first available preset (or nothing).
    plugin_choices = list(PLUGIN_PRESETS.keys())
    if "Smooth Blend" in plugin_choices:
        default_plugin = "Smooth Blend"
    else:
        default_plugin = plugin_choices[0] if plugin_choices else None

    hero_html = "\n".join(
        [
            '<div style="text-align:center;">',
            "<h1>AI DJ Transition Generator</h1>",
            '<p>Upload two songs and generate a smooth transition between them. For best results, please use default demo songs and parameters (just simply click the button "<span class="hero-generate-text">Generate transition artifacts</span>").</p>',
            "</div>",
        ]
    )
    how_to_md = "\n".join(
        [
            "### How to use",
            "1. Upload **Song A** (current track) and **Song B** (next track). For demonstration, there are two default songs.",
            "2. Choose a **Transition style plugin**, this will control the style of the transition.",
            "3. Optionally add **Text instruction** (e.g., smooth, rising energy, no vocals).",
            '4. Select **LoRA adapter**, this will control the style of the transition. For demonstration, there is one default LoRA adapter "Our Trained Guitar-Style LoRA", which is trained on guitar-style music by ourselves.',
            "5. Click **Generate transition artifacts**.",
        ]
    )
    outputs_md = "\n".join(
        [
            "### Outputs (If LoRA is selected, there will be results in the LoRA Variant section)",
            "- **Generated transition clip**: AI-generated repaint transition segment.",
            "- **Hard splice baseline (no transition)**: direct cut baseline.",
            "- **No-repaint rough stitch**: stitched baseline without repaint.",
            "- **Final stitched clip**: final result with transition inserted.",
        ]
    )

    # NOTE(review): the container nesting below was reconstructed from a copy
    # of this file whose indentation was lost; confirm it against the
    # intended layout.
    with gr.Blocks(theme=APP_THEME, css=APP_CSS) as demo:
        gr.HTML(hero_html)
        with gr.Row():
            gr.Markdown(how_to_md, container=False, elem_classes=["plain-info"])
            gr.Markdown(outputs_md, container=False, elem_classes=["plain-info"])
        gr.Markdown(default_demo_status, elem_classes=["plain-info"])

        with gr.Row():
            song_a = gr.Audio(
                label="Song A (mix out)",
                type="filepath",
                sources=["upload"],
                value=default_song_a,
            )
            song_b = gr.Audio(
                label="Song B (mix in)",
                type="filepath",
                sources=["upload"],
                value=default_song_b,
            )
        # Replace each fresh upload with its normalized WAV.
        song_a.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_a,
            outputs=song_a,
            queue=False,
        )
        song_b.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_b,
            outputs=song_b,
            queue=False,
        )

        with gr.Row():
            with gr.Column():
                plugin_id = gr.Dropdown(
                    label="Transition style plugin",
                    choices=plugin_choices,
                    value=default_plugin,
                    info="Select the transition style profile used to guide repaint generation.",
                )
            with gr.Column():
                lora_choice = gr.Dropdown(
                    label="LoRA adapter",
                    choices=LORA_DROPDOWN_CHOICES,
                    value="Our Trained Guitar-Style LoRA",
                    info="Select an ACE-Step LoRA adapter to apply during repaint.",
                )
                lora_scale = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=1.2,
                    step=0.05,
                    label="LoRA scale",
                )
            with gr.Column():
                instruction_text = gr.Textbox(
                    label="Text instruction",
                    placeholder="e.g., smooth, rising energy, no vocals",
                    lines=2,
                    info="Optional extra prompt to refine transition mood, texture, and arrangement.",
                )

        with gr.Accordion("Advanced controls", open=False):
            with gr.Row():
                transition_bars = gr.Dropdown(
                    label="Transition period length (bars)",
                    choices=[4, 8, 16],
                    value=8,
                    info="Controls transition duration. Pipeline uses fixed B-base strategy with A as reference.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                pre_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds before seam (Song A context)",
                    info="How much Song A context is included before the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                post_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds after seam (Song B context)",
                    info="How much Song B context is included after the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                analysis_sec = gr.Slider(
                    minimum=10,
                    maximum=90,
                    value=90,
                    step=5,
                    label="Analysis window (seconds)",
                    info="Length of each track window used for BPM/cue analysis and alignment.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                bpm_target = gr.Number(
                    label="Optional BPM target override",
                    value=None,
                    info="Force Song A reference BPM for alignment when auto BPM is not desired.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                creativity_strength = gr.Slider(
                    minimum=1.0,
                    maximum=12.0,
                    value=12.0,
                    step=0.5,
                    label="Creativity strength (guidance)",
                    info="Higher values push stronger prompt/style guidance in repaint generation.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                inference_steps = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=64,
                    step=1,
                    label="ACE-Step inference steps",
                    info="More steps usually improve detail/stability but increase runtime.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                    info="Random seed for reproducibility; use the same value to repeat a run.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                cue_a_sec = gr.Textbox(
                    label="Optional cue A override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song A cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                cue_b_sec = gr.Textbox(
                    label="Optional cue B override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song B cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                output_dir = gr.Textbox(
                    label="Output directory",
                    value="outputs",
                    info="Folder where generated transition artifacts will be saved.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )

        run_btn = gr.Button("Generate transition artifacts", variant="primary", elem_id="run-transition-btn")

        gr.Markdown("### Baseline (No LoRA)")
        baseline_players = _result_audio_row("No LoRA")
        gr.Markdown("### LoRA Variant (generated only when LoRA adapter is selected)")
        lora_players = _result_audio_row("LoRA")

        # Input order must match the positional parameters of _run_transition;
        # output order must match the 8-tuple it returns.
        run_btn.click(
            fn=_run_transition,
            inputs=[
                song_a,
                song_b,
                plugin_id,
                instruction_text,
                transition_bars,
                pre_context_sec,
                post_context_sec,
                analysis_sec,
                bpm_target,
                creativity_strength,
                inference_steps,
                seed,
                cue_a_sec,
                cue_b_sec,
                lora_choice,
                lora_scale,
                output_dir,
            ],
            outputs=[*baseline_players, *lora_players],
        )
    return demo
# Both calls run at import time: the Demucs warm-up first, then construction
# of the module-level `demo` application object.
_prefetch_demucs_weights()
demo = build_ui()

if __name__ == "__main__":
    # Host and port can be overridden through the environment; the defaults
    # listen on all interfaces, port 7860.
    launch_options = {
        "server_name": os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        "server_port": int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        "head": FORCE_DARK_HEAD,
        "footer_links": ["api", "gradio"],
    }
    demo.launch(**launch_options)