# Gradio demo application for MuseControlLite.
import copy
import os
import subprocess
import time
from typing import Dict, List, Optional, Tuple
import spaces
import gradio as gr
import soundfile as sf
import torch

from MuseControlLite_setup import initialize_condition_extractors, process_musical_conditions, setup_MuseControlLite
from config_inference import get_config

# Stable Audio uses fixed-length 47.5s chunks (2097152 / 44100).
# This value is passed to the pipelines as `audio_end_in_s`.
TOTAL_AUDIO_SECONDS = 2097152 / 44100
# Baseline settings from config_inference; deep-copied per request and used for UI defaults.
DEFAULT_CONFIG = get_config()
# First configured prompt, or an empty string when the config has no "text" entries.
DEFAULT_PROMPT = DEFAULT_CONFIG["text"][0] if DEFAULT_CONFIG.get("text") else ""
# Parent directory under which each request gets its own run folder.
OUTPUT_ROOT = os.path.join(DEFAULT_CONFIG["output_dir"], "gradio_runs")
# Options offered by the "Condition types" checkbox group.
CONDITION_CHOICES = ["melody_stereo", "melody_mono", "dynamics", "rhythm", "audio"]
# Checkpoint files whose presence means the startup download can be skipped.
CHECKPOINT_EXPECTED = [
    "./checkpoints/woSDD-all/model_3.safetensors",
    "./checkpoints/woSDD-all/model_1.safetensors",
    "./checkpoints/woSDD-all/model_2.safetensors",
    "./checkpoints/woSDD-all/model.safetensors",
]

# Create the output root at import time so run folders can be created beneath it.
os.makedirs(OUTPUT_ROOT, exist_ok=True)


def ensure_checkpoints() -> None:
    """Fetch the model checkpoints with the ``gdown`` CLI when any expected file is absent.

    Nothing happens when every path in ``CHECKPOINT_EXPECTED`` already exists.
    A failed download is only reported on stdout so that startup can continue.
    """
    needs_download = any(not os.path.exists(ckpt) for ckpt in CHECKPOINT_EXPECTED)
    if needs_download:
        os.makedirs("checkpoints", exist_ok=True)
        download_cmd = ["gdown", "1Q9B333jcq1czA11JKTbM-DHANJ8YqGbP", "--folder"]
        try:
            subprocess.run(download_cmd, check=True)
        except Exception as exc:  # pylint: disable=broad-except
            # Deliberately best-effort: a missing checkpoint will surface as an
            # error at inference time instead of crashing the app on startup.
            print(f"[warn] Checkpoint download failed: {exc}")


# Run the checkpoint check/download once, at import time.
ensure_checkpoints()


class ModelCache:
    """Lazy loader for heavy pipelines and condition extractors.

    Each distinct configuration is loaded once, moved to CUDA, and kept in
    ``self.cache`` so later requests with equivalent settings reuse it.

    NOTE(review): entries are never evicted, so every distinct MuseControlLite
    configuration (condition set, dtype, AP scale) keeps its own pipeline in
    the cache.
    """

    def __init__(self) -> None:
        # Maps the tuple produced by ``_cache_key`` to the loaded payload dict.
        self.cache: Dict[Tuple, Dict] = {}

    @staticmethod
    def _cache_key(config: Dict) -> Tuple:
        """Return the cache key identifying the pipeline that ``config`` needs.

        The plain Stable Audio pipeline is built from the weight dtype alone,
        so its key ignores ``condition_type`` and ``ap_scale``; keying on those
        unused settings would load duplicate copies of the same model.
        """
        if not config["apadapter"]:
            return ("vanilla", config["weight_dtype"])
        return (
            "musecontrol",
            # Sorted so the selection order of the conditions does not matter.
            tuple(sorted(config["condition_type"])),
            config["weight_dtype"],
            float(config["ap_scale"]),
        )

    def get(self, config: Dict) -> Dict:
        """Return the loaded models for ``config``, loading them on first use.

        Args:
            config: Inference settings. Reads ``apadapter`` and
                ``weight_dtype`` and, in MuseControlLite mode,
                ``condition_type`` and ``ap_scale``.

        Returns:
            A dict with the keys ``pipe``, ``condition_extractors`` (``None``
            for the plain pipeline), ``weight_dtype`` (the torch dtype) and
            ``mode`` (``"musecontrol"`` or ``"vanilla"``).
        """
        key = self._cache_key(config)
        if key in self.cache:
            return self.cache[key]

        # Anything other than "fp16" falls back to full precision.
        weight_dtype = torch.float16 if config["weight_dtype"] == "fp16" else torch.float32
        if config["apadapter"]:
            condition_extractors, transformer_ckpt = initialize_condition_extractors(config)
            pipe = setup_MuseControlLite(config, weight_dtype, transformer_ckpt).to("cuda")
            payload = {
                "pipe": pipe,
                "condition_extractors": condition_extractors,
                "weight_dtype": weight_dtype,
                "mode": "musecontrol",
            }
        else:
            # Imported lazily so diffusers is only required for the plain pipeline path.
            from diffusers import StableAudioPipeline

            pipe = StableAudioPipeline.from_pretrained(
                "stabilityai/stable-audio-open-1.0",
                torch_dtype=weight_dtype,
            ).to("cuda")
            payload = {"pipe": pipe, "condition_extractors": None, "weight_dtype": weight_dtype, "mode": "vanilla"}
        self.cache[key] = payload
        return payload


# Module-level cache instance; run_inference fetches its pipelines through it.
model_cache = ModelCache()


def _build_base_config() -> Dict:
    """Return an independent deep copy of ``DEFAULT_CONFIG`` that the caller may mutate."""
    fresh_config = copy.deepcopy(DEFAULT_CONFIG)
    return fresh_config


def _create_run_dir() -> str:
    """Create and return a fresh, unique output directory under ``OUTPUT_ROOT``.

    The directory is named ``run_<milliseconds since epoch>``. If that name is
    already taken (for example two requests within the same millisecond), a
    numeric suffix is appended so that runs never share a directory and
    overwrite each other's ``generated.wav``.

    Returns:
        Path of the newly created directory.
    """
    stamp = int(time.time() * 1000)
    attempt = 0
    while True:
        suffix = f"_{attempt}" if attempt else ""
        run_dir = os.path.join(OUTPUT_ROOT, f"run_{stamp}{suffix}")
        try:
            # exist_ok is left at False on purpose: creation doubles as the uniqueness check.
            os.makedirs(run_dir)
        except FileExistsError:
            attempt += 1
            continue
        return run_dir


def _seed_to_generator(seed: Optional[float]) -> Optional[torch.Generator]:
    if seed is None or seed == "":
        return None
    try:
        seed_int = int(seed)
    except (TypeError, ValueError):
        return None
    generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
    return generator.manual_seed(seed_int)


def _validate_condition_choices(condition_type: Optional[List[str]]) -> List[str]:
    condition_type = condition_type or []
    if "melody_stereo" in condition_type and any(
        choice in condition_type for choice in ("dynamics", "rhythm", "melody_mono")
    ):
        raise gr.Error("`melody_stereo` cannot be combined with dynamics, rhythm, or melody_mono.")
    return condition_type


@spaces.GPU
def run_inference(
    prompt_text: str,
    condition_audio: Optional[str],
    condition_type: Optional[List[str]],
    use_musecontrol: bool,
    no_text: bool,
    negative_text_prompt: str,
    guidance_scale_text: float,
    guidance_scale_con: float,
    guidance_scale_audio: float,
    denoise_step: int,
    weight_dtype: str,
    ap_scale: float,
    sigma_min: float,
    sigma_max: float,
    audio_mask_start: float,
    audio_mask_end: float,
    musical_mask_start: float,
    musical_mask_end: float,
    seed: Optional[float],
):
    """Generate one audio clip from the form values and save it in a new run directory.

    The parameters mirror the Gradio form inputs, in order. With
    ``use_musecontrol`` enabled, conditions are extracted from
    ``condition_audio`` and handed to the pipeline; otherwise the pipeline is
    called without extracted conditions.

    Returns:
        A ``(generated_path, status_md)`` tuple: the path of the written
        ``generated.wav`` and a Markdown bullet list summarising the run.

    Raises:
        gr.Error: For an invalid condition selection, missing condition audio
            in MuseControlLite mode, or any failure during generation.
    """
    condition_type = _validate_condition_choices(condition_type)

    # Start from a private copy of the defaults and overlay the form values.
    config = _build_base_config()
    form_overrides = {
        "text": [prompt_text or ""],
        "audio_files": [condition_audio or ""],
        "apadapter": use_musecontrol,
        "no_text": bool(no_text),
        "negative_text_prompt": negative_text_prompt or "",
        "guidance_scale_text": float(guidance_scale_text),
        "guidance_scale_con": float(guidance_scale_con),
        "guidance_scale_audio": float(guidance_scale_audio),
        "denoise_step": int(denoise_step),
        "weight_dtype": weight_dtype,
        "ap_scale": float(ap_scale),
        "sigma_min": float(sigma_min),
        "sigma_max": float(sigma_max),
        # Empty number fields arrive as None; treat them as 0.
        "audio_mask_start_seconds": float(audio_mask_start or 0),
        "audio_mask_end_seconds": float(audio_mask_end or 0),
        "musical_attribute_mask_start_seconds": float(musical_mask_start or 0),
        "musical_attribute_mask_end_seconds": float(musical_mask_end or 0),
        "show_result_and_plt": False,
        "condition_type": condition_type,
    }
    config.update(form_overrides)

    # MuseControlLite mode needs both a condition selection and a readable audio file.
    if config["apadapter"]:
        if not condition_type:
            raise gr.Error("Select at least one condition type when using MuseControlLite.")
        if not condition_audio:
            raise gr.Error("Upload an audio file for conditioning.")
        if not os.path.exists(condition_audio):
            raise gr.Error("Condition audio file not found.")

    run_dir = _create_run_dir()
    config["output_dir"] = run_dir
    generator = _seed_to_generator(seed)

    try:
        models = model_cache.get(config)
        pipe = models["pipe"].to("cuda")
        pipe.enable_attention_slicing()
        pipe.scheduler.config.sigma_min = config["sigma_min"]
        pipe.scheduler.config.sigma_max = config["sigma_max"]
        if config["no_text"]:
            prompt_for_model = ""
        else:
            prompt_for_model = prompt_text or ""

        with torch.no_grad():
            if config["apadapter"]:
                final_condition, final_condition_audio = process_musical_conditions(
                    config, condition_audio, models["condition_extractors"], run_dir, 0, models["weight_dtype"], pipe
                )
                call_kwargs = {
                    "extracted_condition": final_condition,
                    "extracted_condition_audio": final_condition_audio,
                    "prompt": prompt_for_model,
                    "negative_prompt": config["negative_text_prompt"],
                    "num_inference_steps": config["denoise_step"],
                    "guidance_scale_text": config["guidance_scale_text"],
                    "guidance_scale_con": config["guidance_scale_con"],
                    "guidance_scale_audio": config["guidance_scale_audio"],
                    "num_waveforms_per_prompt": 1,
                    "audio_end_in_s": TOTAL_AUDIO_SECONDS,
                    "generator": generator,
                }
            else:
                # The plain pipeline takes a single guidance scale; the text scale is used for it.
                call_kwargs = {
                    "prompt": prompt_for_model,
                    "negative_prompt": config["negative_text_prompt"],
                    "num_inference_steps": config["denoise_step"],
                    "guidance_scale": config["guidance_scale_text"],
                    "num_waveforms_per_prompt": 1,
                    "audio_end_in_s": TOTAL_AUDIO_SECONDS,
                    "generator": generator,
                }
            audios = pipe(**call_kwargs).audios
            # Only the first waveform is kept; it is transposed before writing.
            output = audios[0].T.float().cpu().numpy()
            sr = pipe.vae.sampling_rate

        generated_path = os.path.join(run_dir, "generated.wav")
        sf.write(generated_path, output, sr)

        if config["apadapter"]:
            mode_label = "MuseControlLite"
        else:
            mode_label = "Stable Audio base"
        if condition_type:
            conditions_label = ", ".join(condition_type)
        else:
            conditions_label = "text only"

        status_lines = [
            f"Run directory: `{run_dir}`",
            f"Mode: {mode_label}",
            f"Condition type: {conditions_label}",
            f"Dtype: {config['weight_dtype']}, steps: {config['denoise_step']}, sigma [{config['sigma_min']}, {config['sigma_max']}]",
        ]
        if config["apadapter"]:
            status_lines.append(
                f"Guidance (text/cond/audio): {config['guidance_scale_text']}/{config['guidance_scale_con']}/{config['guidance_scale_audio']}"
            )
        if generator is not None:
            status_lines.append(f"Seed: {int(seed)}")

        status_md = "\n".join([f"- {entry}" for entry in status_lines])
        return generated_path, status_md
    except gr.Error:
        # Already user-facing; pass through unchanged.
        raise
    except Exception as err:  # pylint: disable=broad-except
        raise gr.Error(f"Generation failed: {err}") from err


# Preset form values for the "Quick start examples" section.
# Each row is positional and follows the order of the form inputs passed to
# gr.Examples (which is also the parameter order of run_inference).
EXAMPLES = [
    [
        # prompt
        "Electronic music that has a constant melody throughout with accompanying instruments used to supplement the melody which can be heard in possibly a casual setting",
        "melody_condition_audio/49_piano.mp3",  # condition_audio
        ["melody_stereo"],  # condition_type
        True,  # use_musecontrol
        False,  # no_text
        "",  # negative_prompt
        7.0,  # guidance_scale_text
        1.5,  # guidance_scale_con
        1.0,  # guidance_scale_audio
        50,  # denoise_step
        "fp16",  # weight_dtype
        1.0,  # ap_scale
        0.3,  # sigma_min
        500,  # sigma_max
        0,  # audio_mask_start
        0,  # audio_mask_end
        0,  # musical_mask_start
        0,  # musical_mask_end
        42,  # seed
    ],
    [
        # prompt
        "fast and fun beat-based indie pop to set a protagonist-gets-good-at-x movie montage to.",
        "melody_condition_audio/610_bass.mp3",  # condition_audio
        ["melody_mono", "dynamics", "rhythm"],  # condition_type
        True,  # use_musecontrol
        False,  # no_text
        "",  # negative_prompt
        7.0,  # guidance_scale_text
        1.5,  # guidance_scale_con
        1.0,  # guidance_scale_audio
        50,  # denoise_step
        "fp16",  # weight_dtype
        1.0,  # ap_scale
        0.3,  # sigma_min
        500,  # sigma_max
        0,  # audio_mask_start
        0,  # audio_mask_end
        0,  # musical_mask_start
        0,  # musical_mask_end
        7,  # seed
    ],
]


def build_interface() -> gr.Blocks:
    """Assemble the Gradio Blocks UI for the demo and return it without launching it.

    The layout holds the prompt and mode toggles, the condition audio and
    condition type pickers, an "Advanced controls" accordion, the Generate
    button with its two outputs, and a set of clickable examples.
    """
    with gr.Blocks(title="MuseControlLite") as demo:
        gr.Markdown(
            """
            ## MuseControlLite demo
            UI for MuseControlLite (47.5s generations). This Space downloads checkpoints on startup with gdown and expects a GPU runtime; duplicate to a GPU Space or run locally for actual generation.
            """
        )

        # Prompt and mode toggles share the first row.
        with gr.Row():
            prompt = gr.Textbox(label="Text prompt", lines=3, value=DEFAULT_PROMPT)
            use_musecontrol = gr.Checkbox(label="Use MuseControlLite adapters", value=True)
            no_text = gr.Checkbox(label="Ignore text prompt (audio-only guidance)", value=False)

        condition_audio = gr.Audio(
            label="Condition audio (required for MuseControlLite)",
            type="filepath",
            sources=["upload", "microphone"],
        )
        condition_type = gr.CheckboxGroup(
            CONDITION_CHOICES,
            label="Condition types",
            value=DEFAULT_CONFIG.get("condition_type", []),
        )

        with gr.Accordion("Advanced controls", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                lines=2,
                value=DEFAULT_CONFIG.get("negative_text_prompt", ""),
            )

            # Guidance scales.
            with gr.Row():
                guidance_scale_text = gr.Slider(
                    label="Guidance scale (text)",
                    minimum=0.0,
                    maximum=12.0,
                    step=0.1,
                    value=DEFAULT_CONFIG["guidance_scale_text"],
                )
                guidance_scale_con = gr.Slider(
                    label="Guidance scale (conditions)",
                    minimum=0.0,
                    maximum=5.0,
                    step=0.1,
                    value=DEFAULT_CONFIG["guidance_scale_con"],
                )
                guidance_scale_audio = gr.Slider(
                    label="Guidance scale (audio)",
                    minimum=0.0,
                    maximum=5.0,
                    step=0.1,
                    value=DEFAULT_CONFIG["guidance_scale_audio"],
                )

            # Sampling steps, precision and adapter scale.
            with gr.Row():
                denoise_step = gr.Slider(
                    label="Denoising steps",
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=DEFAULT_CONFIG["denoise_step"],
                )
                weight_dtype = gr.Radio(
                    ["fp16", "fp32"],
                    label="Weight dtype",
                    value=DEFAULT_CONFIG["weight_dtype"],
                )
                ap_scale = gr.Slider(
                    label="AP scale",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.05,
                    value=DEFAULT_CONFIG["ap_scale"],
                )

            # Scheduler sigma range and optional seed.
            with gr.Row():
                sigma_min = gr.Slider(
                    label="Scheduler sigma min",
                    minimum=0.1,
                    maximum=5.0,
                    step=0.05,
                    value=DEFAULT_CONFIG["sigma_min"],
                )
                sigma_max = gr.Slider(
                    label="Scheduler sigma max",
                    minimum=50,
                    maximum=700,
                    step=1,
                    value=DEFAULT_CONFIG["sigma_max"],
                )
                seed = gr.Number(label="Seed (optional)", precision=0)

            # Mask windows, in seconds.
            with gr.Row():
                audio_mask_start = gr.Number(
                    label="Audio mask start (s)",
                    value=DEFAULT_CONFIG["audio_mask_start_seconds"],
                )
                audio_mask_end = gr.Number(
                    label="Audio mask end (s)",
                    value=DEFAULT_CONFIG["audio_mask_end_seconds"],
                )
            with gr.Row():
                musical_mask_start = gr.Number(
                    label="Musical attribute mask start (s)",
                    value=DEFAULT_CONFIG["musical_attribute_mask_start_seconds"],
                )
                musical_mask_end = gr.Number(
                    label="Musical attribute mask end (s)",
                    value=DEFAULT_CONFIG["musical_attribute_mask_end_seconds"],
                )

        generate_btn = gr.Button("Generate", variant="primary")
        generated_audio = gr.Audio(label="Generated audio", type="filepath")
        status = gr.Markdown(label="Run details")

        # One ordered list serves both the click handler and the examples; the
        # order must match run_inference's parameters and each row of EXAMPLES.
        form_inputs = [
            prompt,
            condition_audio,
            condition_type,
            use_musecontrol,
            no_text,
            negative_prompt,
            guidance_scale_text,
            guidance_scale_con,
            guidance_scale_audio,
            denoise_step,
            weight_dtype,
            ap_scale,
            sigma_min,
            sigma_max,
            audio_mask_start,
            audio_mask_end,
            musical_mask_start,
            musical_mask_end,
            seed,
        ]

        generate_btn.click(
            fn=run_inference,
            inputs=form_inputs,
            outputs=[generated_audio, status],
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=form_inputs,
            label="Quick start examples (click to populate the form)",
        )
    return demo


# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()