# Yng314
# refactor: move inline text styling to a dedicated CSS class
# 9a75ced
import logging
import os
import subprocess
from pathlib import Path
from typing import Optional, Tuple
import gradio as gr
import spaces
from huggingface_hub import hf_hub_download
from pipeline.transition_generator import (
PLUGIN_PRESETS,
TransitionRequest,
generate_transition_artifacts,
)
# Root logging configuration: timestamped "time | level | logger | message"
# records for startup prefetch and request-time diagnostics.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
# Module-level logger used by the helpers below.
LOGGER = logging.getLogger(__name__)
# Options shown in the "LoRA adapter" dropdown; "None" skips the LoRA variant.
LORA_DROPDOWN_CHOICES = [
    "None",
    "Chinese New Year (official)",
    "Our Trained Guitar-Style LoRA",
]
# Maps a dropdown label to the Hugging Face repo id holding the LoRA weights.
# "None" is deliberately absent so a `.get(choice, "")` lookup yields "" (no LoRA).
LORA_REPO_MAP = {
    "Chinese New Year (official)": "ACE-Step/ACE-Step-v1.5-chinese-new-year-LoRA",
    "Our Trained Guitar-Style LoRA": "yng314/audio_generation_lora",
}
# Custom CSS injected into the Blocks app:
# - .adv-item: keep advanced-control labels on a single ellipsized line.
# - .result-audio-label: let the "\n" embedded in result labels render as a break.
# - .hero-generate-text: green highlight for the button name in the hero copy.
# - #run-transition-btn: green primary button with a darker hover state.
APP_CSS = """
.adv-item label,
.adv-item .gr-block-label,
.adv-item .gr-block-title {
white-space: nowrap !important;
overflow: hidden !important;
text-overflow: ellipsis !important;
}
.result-audio-label label,
.result-audio-label .gr-block-label,
.result-audio-label .gr-block-title {
white-space: pre-line !important;
}
.hero-generate-text {
color: #16a34a !important;
font-weight: 600;
}
#run-transition-btn,
#run-transition-btn button {
background: #16a34a !important;
background-image: none !important;
border-color: #16a34a !important;
color: #ffffff !important;
}
#run-transition-btn:hover,
#run-transition-btn button:hover {
background: #15803d !important;
background-image: none !important;
border-color: #15803d !important;
}
"""
# Gradio Soft theme variant: blue primary, slate neutrals, and the extra-large
# radius applied uniformly to blocks, inputs, and all button sizes.
APP_THEME = gr.themes.Soft(
    primary_hue="blue",
    neutral_hue="slate",
    radius_size="lg",
).set(
    block_radius="*radius_xl",
    input_radius="*radius_xl",
    button_large_radius="*radius_xl",
    button_medium_radius="*radius_xl",
    button_small_radius="*radius_xl",
)
# <head> snippet that forces Gradio's dark theme: it redirects once to append
# `?__theme=dark` to the URL, then tags <html> with the `dark` class as early
# as possible. Passed to `demo.launch(head=...)` below.
FORCE_DARK_HEAD = """
<script>
(() => {
try {
const url = new URL(window.location.href);
if (url.searchParams.get("__theme") !== "dark") {
url.searchParams.set("__theme", "dark");
window.location.replace(url.toString());
return;
}
// Ensure dark class is present as early as possible.
document.documentElement.classList.add("dark");
} catch (err) {
// No-op: fail open if URL manipulation is unavailable.
}
})();
</script>
"""
# Dataset repo and filenames for the bundled demo songs; all overridable via
# environment. The song names fall back to the defaults when blanked, while
# the repo id may be blanked (which disables default-demo loading downstream).
DEFAULT_DEMO_REPO = os.getenv("AI_DJ_DEFAULT_DEMO_REPO", "yng314/audio-demo-private").strip()
DEFAULT_DEMO_SONG_A = os.getenv("AI_DJ_DEFAULT_DEMO_SONG_A", "song_a.mp3").strip() or "song_a.mp3"
DEFAULT_DEMO_SONG_B = os.getenv("AI_DJ_DEFAULT_DEMO_SONG_B", "song_b.mp3").strip() or "song_b.mp3"
def _env_flag(name: str, default: bool) -> bool:
raw = os.getenv(name, "1" if default else "0").strip().lower()
return raw not in {"0", "false", "no", "off"}
def _prefetch_demucs_weights() -> None:
    """Warm the Demucs checkpoint cache during startup.

    Downloading the separation model lazily would make the first ZeroGPU
    request time out, so the checkpoint is fetched eagerly unless disabled
    via AI_DJ_PREFETCH_DEMUCS. Failures are logged and swallowed so the
    app still starts without the prefetch.
    """
    if not _env_flag("AI_DJ_PREFETCH_DEMUCS", True):
        return
    configured = os.getenv("AI_DJ_DEMUCS_MODEL", "htdemucs").strip()
    model_name = configured or "htdemucs"
    try:
        from demucs.pretrained import get_model  # type: ignore
    except Exception as exc:
        LOGGER.warning("Demucs prefetch skipped/failed (%s).", exc)
        return
    try:
        LOGGER.info("Prefetching Demucs model '%s'...", model_name)
        get_model(model_name)
    except Exception as exc:
        LOGGER.warning("Demucs prefetch skipped/failed (%s).", exc)
        return
    LOGGER.info("Demucs model '%s' prefetch complete.", model_name)
def _to_optional_float(value) -> Optional[float]:
if value is None:
return None
if isinstance(value, str) and not value.strip():
return None
try:
return float(value)
except Exception:
return None
def _normalize_upload_for_ui(path: Optional[str]) -> Optional[str]:
if not path:
return path
src = str(path)
if not os.path.isfile(src):
return path
out_dir = os.path.join("outputs", "normalized_uploads")
os.makedirs(out_dir, exist_ok=True)
stem = Path(src).stem
dst = os.path.join(out_dir, f"{stem}_ui_norm.wav")
cmd = [
"ffmpeg",
"-hide_banner",
"-loglevel",
"error",
"-nostdin",
"-y",
"-i",
src,
"-vn",
"-ac",
"2",
"-ar",
"44100",
"-c:a",
"pcm_s16le",
dst,
]
try:
subprocess.run(cmd, check=True)
return dst
except Exception as exc:
LOGGER.warning("Upload normalization failed for %s (%s). Using original file.", src, exc)
return src
def _download_default_demo_song(repo_id: str, filename: str, token: Optional[str]) -> Optional[str]:
if not repo_id or not filename:
return None
try:
local_path = hf_hub_download(
repo_id=repo_id,
repo_type="dataset",
filename=filename,
token=token,
local_dir="outputs/default_inputs",
)
return _normalize_upload_for_ui(local_path)
except Exception as exc:
LOGGER.warning("Default demo song download failed for %s/%s (%s).", repo_id, filename, exc)
return None
def _resolve_default_demo_inputs() -> Tuple[Optional[str], Optional[str], str]:
    """Attempt to pre-load the two bundled demo songs.

    Returns ``(song_a_path, song_b_path, status_markdown)``. The paths are
    None when the feature is disabled, the HF token is absent, or either
    download fails; the status string explains the outcome to the user.
    """
    if not _env_flag("AI_DJ_ENABLE_DEFAULT_DEMO", True):
        return None, None, "Default demo songs disabled (AI_DJ_ENABLE_DEFAULT_DEMO=0)."
    token = os.getenv("HF_TOKEN", "").strip() or None
    if token is None:
        return None, None, "Default demo songs not loaded: missing HF_TOKEN secret."
    demo_paths = [
        _download_default_demo_song(DEFAULT_DEMO_REPO, name, token)
        for name in (DEFAULT_DEMO_SONG_A, DEFAULT_DEMO_SONG_B)
    ]
    song_a_default, song_b_default = demo_paths
    if song_a_default and song_b_default:
        loaded_status = (
            f"Default demo songs loaded from `{DEFAULT_DEMO_REPO}` "
            f"(`{DEFAULT_DEMO_SONG_A}`, `{DEFAULT_DEMO_SONG_B}`)."
        )
        return song_a_default, song_b_default, loaded_status
    missing_status = (
        f"Default demo songs not loaded from `{DEFAULT_DEMO_REPO}`; "
        "please upload Song A and Song B manually."
    )
    return None, None, missing_status
@spaces.GPU(duration=120)
def _run_transition(
    song_a,
    song_b,
    plugin_id,
    instruction_text,
    transition_bars,
    pre_context_sec,
    post_context_sec,
    analysis_sec,
    bpm_target,
    creativity_strength,
    inference_steps,
    seed,
    cue_a_sec,
    cue_b_sec,
    lora_choice,
    lora_scale,
    output_dir,
):
    """Generate baseline (no-LoRA) transition artifacts, plus a LoRA variant
    when a LoRA adapter is selected in the dropdown.

    Returns:
        An 8-tuple of file paths: (transition, hard splice, rough stitch,
        final stitch) for the baseline, then the same four for the LoRA
        variant (all None when no LoRA adapter is selected).

    Raises:
        gr.Error: when either song is missing or generation fails.
    """
    if not song_a or not song_b:
        raise gr.Error("Please upload both Song A and Song B.")
    # "None" (and any unknown choice) is absent from the map -> "" -> skip LoRA run.
    selected_lora_path = LORA_REPO_MAP.get(str(lora_choice), "")
    output_root = (output_dir or "outputs").strip()

    def build_request(lora_path: str, sub_dir: str) -> TransitionRequest:
        # The baseline and LoRA runs differ only in the adapter path and the
        # output sub-directory; everything else is shared verbatim.
        return TransitionRequest(
            song_a_path=song_a,
            song_b_path=song_b,
            plugin_id=plugin_id,
            instruction_text=instruction_text or "",
            transition_base_mode="B-base-fixed",
            transition_bars=int(transition_bars),
            pre_context_sec=float(pre_context_sec),
            repaint_width_sec=4.0,
            post_context_sec=float(post_context_sec),
            analysis_sec=float(analysis_sec),
            bpm_target=_to_optional_float(bpm_target),
            cue_a_sec=_to_optional_float(cue_a_sec),
            cue_b_sec=_to_optional_float(cue_b_sec),
            creativity_strength=float(creativity_strength),
            inference_steps=int(inference_steps),
            seed=int(seed),
            acestep_lora_path=lora_path,
            acestep_lora_scale=float(lora_scale),
            output_dir=os.path.join(output_root, sub_dir),
        )

    try:
        baseline = generate_transition_artifacts(build_request("", "compare_no_lora"))
    except Exception as exc:
        # Chain the cause so the original traceback survives the gr.Error wrap.
        raise gr.Error(str(exc)) from exc

    lora_transition = None
    lora_hard_splice = None
    lora_rough_stitched = None
    lora_stitched = None
    if selected_lora_path:
        try:
            lora_result = generate_transition_artifacts(
                build_request(selected_lora_path, "compare_lora")
            )
        except Exception as exc:
            raise gr.Error(f"Baseline generated, but LoRA variant failed: {exc}") from exc
        lora_transition = lora_result.transition_path
        lora_hard_splice = lora_result.hard_splice_path
        lora_rough_stitched = lora_result.rough_stitched_path
        lora_stitched = lora_result.stitched_path
    return (
        baseline.transition_path,
        baseline.hard_splice_path,
        baseline.rough_stitched_path,
        baseline.stitched_path,
        lora_transition,
        lora_hard_splice,
        lora_rough_stitched,
        lora_stitched,
    )
def build_ui() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the transition generator.

    Layout: hero header, usage/output notes, the two song inputs (with an
    upload-time normalization hook), style/LoRA/instruction controls, an
    advanced-settings accordion, the run button, and two rows of result
    players (baseline and LoRA variant) wired to :func:`_run_transition`.

    Fixes two user-facing typos in the instructions ("demonstartion" ->
    "demonstration"); all widgets and wiring are otherwise unchanged.
    """
    # Best-effort pre-load of the bundled demo tracks; paths may be None.
    default_song_a, default_song_b, default_demo_status = _resolve_default_demo_inputs()
    with gr.Blocks(theme=APP_THEME, css=APP_CSS) as demo:
        # --- Hero header -------------------------------------------------
        gr.HTML(
            """
<div style="text-align:center;">
<h1>AI DJ Transition Generator</h1>
<p>Upload two songs and generate a smooth transition between them. For best results, please use default demo songs and parameters (just simply click the button "<span class="hero-generate-text">Generate transition artifacts</span>").</p>
</div>
""".strip()
        )
        # --- Usage and output notes, side by side ------------------------
        with gr.Row():
            gr.Markdown(
                """
### How to use
1. Upload **Song A** (current track) and **Song B** (next track). For demonstration, there are two default songs.
2. Choose a **Transition style plugin**, this will control the style of the transition.
3. Optionally add **Text instruction** (e.g., smooth, rising energy, no vocals).
4. Select **LoRA adapter**, this will control the style of the transition. For demonstration, there is one default LoRA adapter "Our Trained Guitar-Style LoRA", which is trained on guitar-style music by ourselves.
5. Click **Generate transition artifacts**.
""".strip(),
                container=False,
                elem_classes=["plain-info"],
            )
            gr.Markdown(
                """
### Outputs (If LoRA is selected, there will be results in the LoRA Variant section)
- **Generated transition clip**: AI-generated repaint transition segment.
- **Hard splice baseline (no transition)**: direct cut baseline.
- **No-repaint rough stitch**: stitched baseline without repaint.
- **Final stitched clip**: final result with transition inserted.
""".strip(),
                container=False,
                elem_classes=["plain-info"],
            )
        gr.Markdown(default_demo_status, elem_classes=["plain-info"])
        # --- Song inputs --------------------------------------------------
        with gr.Row():
            song_a = gr.Audio(
                label="Song A (mix out)",
                type="filepath",
                sources=["upload"],
                value=default_song_a,
            )
            song_b = gr.Audio(
                label="Song B (mix in)",
                type="filepath",
                sources=["upload"],
                value=default_song_b,
            )
        # Re-encode uploads to a uniform WAV immediately, replacing the
        # widget's value in place; queue=False keeps it snappy.
        song_a.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_a,
            outputs=song_a,
            queue=False,
        )
        song_b.upload(
            fn=_normalize_upload_for_ui,
            inputs=song_b,
            outputs=song_b,
            queue=False,
        )
        # --- Primary controls: style plugin, LoRA, text instruction ------
        with gr.Row():
            with gr.Column():
                plugin_id = gr.Dropdown(
                    label="Transition style plugin",
                    choices=list(PLUGIN_PRESETS.keys()),
                    value="Smooth Blend",
                    info="Select the transition style profile used to guide repaint generation.",
                )
            with gr.Column():
                lora_choice = gr.Dropdown(
                    label="LoRA adapter",
                    choices=LORA_DROPDOWN_CHOICES,
                    value="Our Trained Guitar-Style LoRA",
                    info="Select an ACE-Step LoRA adapter to apply during repaint.",
                )
                lora_scale = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=1.2,
                    step=0.05,
                    label="LoRA scale",
                )
            with gr.Column():
                instruction_text = gr.Textbox(
                    label="Text instruction",
                    placeholder="e.g., smooth, rising energy, no vocals",
                    lines=2,
                    info="Optional extra prompt to refine transition mood, texture, and arrangement.",
                )
        # --- Advanced controls (collapsed by default) ---------------------
        with gr.Accordion("Advanced controls", open=False):
            with gr.Row():
                transition_bars = gr.Dropdown(
                    label="Transition period length (bars)",
                    choices=[4, 8, 16],
                    value=8,
                    info="Controls transition duration. Pipeline uses fixed B-base strategy with A as reference.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                pre_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds before seam (Song A context)",
                    info="How much Song A context is included before the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                post_context_sec = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=12,
                    step=0.5,
                    label="Seconds after seam (Song B context)",
                    info="How much Song B context is included after the repaint region.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                analysis_sec = gr.Slider(
                    minimum=10,
                    maximum=90,
                    value=90,
                    step=5,
                    label="Analysis window (seconds)",
                    info="Length of each track window used for BPM/cue analysis and alignment.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                bpm_target = gr.Number(
                    label="Optional BPM target override",
                    value=None,
                    info="Force Song A reference BPM for alignment when auto BPM is not desired.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                creativity_strength = gr.Slider(
                    minimum=1.0,
                    maximum=12.0,
                    value=12.0,
                    step=0.5,
                    label="Creativity strength (guidance)",
                    info="Higher values push stronger prompt/style guidance in repaint generation.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                inference_steps = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=64,
                    step=1,
                    label="ACE-Step inference steps",
                    info="More steps usually improve detail/stability but increase runtime.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0,
                    info="Random seed for reproducibility; use the same value to repeat a run.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                cue_a_sec = gr.Textbox(
                    label="Optional cue A override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song A cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
            with gr.Row():
                cue_b_sec = gr.Textbox(
                    label="Optional cue B override (sec)",
                    value="",
                    placeholder="Leave blank for auto cue selection",
                    info="Manually set Song B cue point in seconds; blank uses automatic selection.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
                output_dir = gr.Textbox(
                    label="Output directory",
                    value="outputs",
                    info="Folder where generated transition artifacts will be saved.",
                    min_width=320,
                    elem_classes=["adv-item"],
                )
        run_btn = gr.Button("Generate transition artifacts", variant="primary", elem_id="run-transition-btn")
        # --- Baseline results ---------------------------------------------
        gr.Markdown("### Baseline (No LoRA)")
        with gr.Row():
            transition_audio = gr.Audio(
                label="Generated transition clip\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            hard_splice_audio = gr.Audio(
                label="Hard splice baseline\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            rough_stitched_audio = gr.Audio(
                label="No-repaint rough stitch\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            stitched_audio = gr.Audio(
                label="Final stitched clip\n(No LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
        # --- LoRA-variant results -----------------------------------------
        gr.Markdown("### LoRA Variant (generated only when LoRA adapter is selected)")
        with gr.Row():
            lora_transition_audio = gr.Audio(
                label="Generated transition clip\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_hard_splice_audio = gr.Audio(
                label="Hard splice baseline\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_rough_stitched_audio = gr.Audio(
                label="No-repaint rough stitch\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
            lora_stitched_audio = gr.Audio(
                label="Final stitched clip\n(LoRA)",
                type="filepath",
                elem_classes=["result-audio-label"],
            )
        # --- Wiring: the button drives the 17-input / 8-output pipeline ---
        run_btn.click(
            fn=_run_transition,
            inputs=[
                song_a,
                song_b,
                plugin_id,
                instruction_text,
                transition_bars,
                pre_context_sec,
                post_context_sec,
                analysis_sec,
                bpm_target,
                creativity_strength,
                inference_steps,
                seed,
                cue_a_sec,
                cue_b_sec,
                lora_choice,
                lora_scale,
                output_dir,
            ],
            outputs=[
                transition_audio,
                hard_splice_audio,
                rough_stitched_audio,
                stitched_audio,
                lora_transition_audio,
                lora_hard_splice_audio,
                lora_rough_stitched_audio,
                lora_stitched_audio,
            ],
        )
    return demo
# Warm the Demucs cache and build the UI at import time so hosting platforms
# that import `demo` (rather than running this file) find it ready.
_prefetch_demucs_weights()
demo = build_ui()
if __name__ == "__main__":
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")),
        # `head` injects the dark-mode bootstrap script into the page <head>.
        head=FORCE_DARK_HEAD,
        # NOTE(review): `footer_links` is only accepted by recent Gradio
        # releases — confirm the pinned version supports this kwarg.
        footer_links=["api", "gradio"],
    )